ちりもつもればミルキーウェイ

好奇心に可処分時間が奪われる

152行でprotobufの簡易的なunmarshalerかいた

はじめに

protobufのエンコーディングwireという名前で以下に仕様がある

developers.google.com

最近仕様を読んだりproto定義から生成される成果物を読んだりいろいろしていて、Unmarshalerかける気がしたので書いてみた。

すべての仕様に準拠するものは大変なので、現状はネストしない構造かつtypeがVarint(のenum以外)、Length-delimited(のstring, bytesのみ)だけサポートしている。

以下のようなproto定義で生成されるメッセージがUnmarshalできます

message Test {
    int32  Int32 = 1;
    int64  Int64 = 2;
    bool   Boolean = 3;
    sint32 Sint32 = 4;
    sint64 Sint64 = 5;
    string Str   = 6;
    bytes  Bytes = 7;
}

仕様を網羅しなければ意外と150行くらいで書けちゃってお〜なるほどとなった。まあ本当は仕様を網羅しないといけないんですが。

ざっくり動くものはこちら(絶賛実装中なので最新のmainは152行超えてます。この記事の目的は簡易的な実装は案外簡単にできると示すことなので、当時のコミットを貼っておきます)

github.com

それではいくぞー

wire形式のバイナリの読み方

wire形式のバイナリについてkey, valueにわけて読み方をまとめる。
ちなみにkeyのことをtagとよぶっぽい

valueはさらに型によって読み方が違うので、それぞれの型の読み方をまとめる

tagのよみかた

wire形式のバイナリは、いくつかのkey:valueが仕様に則ってエンコーディングされています。

で、このkeyはドキュメントでは tag と呼ばれています。
タグの下位3bitはそのフィールドの型を、それより上位のbitはフィールド番号を示しています。

フィールドの型については仕様書にきっかり明記されており、wireにおける型は以下の6種類です

Type Meaning Used For
0 Varint int32, int64, uint32, uint64, sint32, sint64, bool, enum
1 64-bit fixed64, sfixed64, double
2 Length-delimited string, bytes, embedded messages, packed repeated fields
3 Start group groups (deprecated)
4 End group groups (deprecated)
5 32-bit fixed32, sfixed32, float

このうち3と4は下位互換のために残っているだけなので今回は無視します
(wireのいいところのひとつはやたらと下位互換を考慮しているところだからあんまり良くないけどあくまでエンコーディングの理解が目的なので)

フィールドについては最大のサイズが決まっており、

https://developers.google.com/protocol-buffers/docs/proto3#assigning_field_numbers

The smallest field number you can specify is 1, and the largest is 229 - 1, or 536,870,911.

とあるように、フィールド番号は可変長で最大29bitです。 それぞれのbyteのMSB(最上位bit)が終端かどうかの判断に使わることを考えると、tagは最大5byteかな( (1+7)*4 + (1+1+3) 3bit余るのが気になるけど通常ここまで大きいfield_number使うこと無いだろうしあんまり問題無いんだろうな。わりと気味の悪い半端さなので僕の計算が間違えてるかも )

Goの実装 みると最大64bitまでパースしそうでなんでなん?とおもっている。typeとあわせて32bitまでじゃないんすか? -> とおもってたら ここ で考慮されてたので問題なかった

たとえば

0000 1000

みたいなバイトが渡されたときこういう感じで判断される

tagはこんな感じで判断される

みたとおり、field_numberが15(4bitの最大値)までだったらtagが1byteですむので小さく済むよということです。ドキュメントにも言及があります
https://developers.google.com/protocol-buffers/docs/proto3#assigning_field_numbers

Note that field numbers in the range 1 through 15 take one byte to encode, including the field number and the field's type (you can find out more about this in Protocol Buffer Encoding). Field numbers in the range 16 through 2047 take two bytes. So you should reserve the numbers 1 through 15 for very frequently occurring message elements.

ここまでよむとわかると思いますが、wire形式のエンコーディングでは メッセージのフィールド名は一切考慮されません。

jsonなどのキーがstringのエンコーディングだとバイト数も多いしMarshaler/Unmarshaler間で一致したフィールド名で操作しないと正しく結果が読み取れません。

wireだとキーもfield_numberが15を超えるまでは1byteなので大抵の場合はstringキーよりもスリムだし、field_numberがあってればserverでは Name string と思っていてclientは Identifier string と思っていても全く問題ないわけです。

また、同じwire typeであればプログラム側の型を変えることも可能です。たとえばserverではint64だとおもっているけど、clientはint32だと思っている!なんてときでもwire的にはちゃんと値をパースできます
(とはいえオーバーフローするとserver側の本来表現したかった値と違うものになってしまうのでそんな簡単な話ではないが)

というわけで、この例ではMSBが0なので終端byteであり、下位3bitはtypeを表していてそれが0なので Varint(可変長バイト列) であり、のこりはfield_numberで1ということになる
tagをパースするコードは以下のようにかける

func Unmarshal(b []byte, v interface{}) error {
    for len(b) > 0 { {
        // タグは可変長バイト列形式
        tag, n, err := readVarint(b)
        if err != nil {
            return err
        }
        b = b[n:]
        // 仕様でtype, field_number合わせて32bitまでなので超えてたらエラー
        if tag > math.MaxUint32 {
            return OverflowErr
        }
        // 下位3bitはtype, それ以外はfield_number
        fieldNum := tag >> 3
        tp := tag & 0x7
        // 読み取った情報をlogにはく
        fmt.Printf("readed byte size: %d, field_number: %d, type: %d", n, fieldNum, tp)
    }
    return nil
}

// readVarint は可変長バイト列の読み取り処理
func readVarint(b []byte) (v uint64, n int, err error) {
    // little endian で読み取っていく
    for shift := uint(0); ; shift += 7 {
        // 64bitこえたらoverflow
        if shift >= 64 {
            return 0, 0, OverflowErr
        }
        // 対象のbyteの下位7bitを読み取ってvにつめていく
        target := b[n]
        n++
        v |= uint64(target&0x7F) << shift
        // 最上位bitが0だったら終端なのでよみとり終了
        if target < 0x80 {
            break
        }
    }
    return v, n, nil
}

さっきの例の

0000 1000

みたいなバイトを食わせる以下のような処理を実行すると

func main() {
    b := []byte{0b00001000}
    protowire.Unmarshal(b, nil)
}
$ go run main.go
readed byte size: 1, field_number: 1, type: 0

となり、問題なく読めてそうです。やったね!

fieldの読み方

とりあえずfieldを読み取る前に、値をbindしたいstructなりの各フィールドのwire typeとfield numberを知ってる必要がある
これはstructにタグつければ良くて、ここでは適当に以下のようなタグ形式とする

type User struct{
    Age  int64  `protowire:"1,0"` // field_number = 1, type = 0(Varint)
    Name string `protowire:"2,2"` // field_number = 2, type = 2(Length-delimited)
}

これはreflectつかって適当によめばいい。
ちゃんとやるならmapあたりにも対応したほうがいいし、tagの内容が順不同でもOKなつくりにしたほうがいいけど今回は別にそこまでちゃんとやらんでもよいので適当に書く
一応typeは3bit値で最大7まであるので、7までは許容しておく(wireの仕様には7のときの定義はないけど一応)

type structTag struct {
    tp             uint8
    structFieldNum int
}

// parseTags はstructに振ってあるprotowireタグを読み取ってmapに変換する
// mapのキーはfield_number
func parseTags(v interface{}) (map[uint32]structTag, error) {
    if reflect.ValueOf(v).Kind() != reflect.Ptr {
        return nil, errors.New("struct must be a pointer")
    }
    rt := reflect.Indirect(reflect.ValueOf(v)).Type()
    fieldSize := rt.NumField()
    tags := make(map[uint32]structTag, fieldSize)
    for i := 0; i < fieldSize; i++ {
        f := rt.Field(i)
        t := strings.Split(f.Tag.Get("protowire"), ",")
        fieldNum, err := strconv.Atoi(t[0])
        if err != nil {
            return nil, err
        }
        if fieldNum > 1<<29-1 {
            return nil, errors.New("invalid protowire structTag, largest field_number is 536,870,911")
        }
        tp, err := strconv.Atoi(t[1])
        if tp > 7 {
            return nil, errors.New("invalid protowire structTag, largest type is 7")
        }
        tags[uint32(fieldNum)] = structTag{tp: uint8(tp), structFieldNum: i}
    }
    return tags, nil
}

でunmarshalの頭でinterface{}からタグをよみとっておけばよい

func Unmarshal(b []byte, v interface{}) error {
+   sts, err := parseTags(v)
+   if err != nil {
+       return fmt.Errorf("failed to parse structTag from input interface{}: %w", err)
+   }
    〜略〜
}

Varintの場合

これは簡単で、もうすでにある readVarint の処理がそのまま使えてしまう

で、カンのいいかたはお気づきだと思うんですがtagを読み取るときと違って下位3bitは別の値!とかないので1byteあたり最大127までの値を扱えます。
可変長な部分はすべてこの readVarint を使い回すので、つまりはtag以外の可変長バイト列は値が 127 以下だとコスパいいってことですね。(たとえばstirngは文字長をVarint形式でバイナリに詰めるので、127文字以下だと文字長の情報が1byteですむ。)

さっきまでのunmarshal処理でtagのtypeとfield_numberとれてるので、 readVarint の結果をstructの該当fieldにbindすればよいのでこんな処理を足す

func Unmarshal(b []byte, v interface{}) error {
    sts, err := parseTags(v)
    if err != nil {
        return fmt.Errorf("failed to parse structTag from input interface{}: %w", err)
    }
    for len(b) > 0 {
        〜略〜
-      // 読み取った情報をlogにはく
-      fmt.Printf("readed byte size: %d, field_number: %d, type: %d", n, fieldNum, tp)
+       st := sts[fieldNum]
+       if st.tp != tp {
+           return fmt.Errorf("wrong type, structTag type: %d, wire type: %d", st.tp, tp)
+       }
+       switch tp {
+       case 0:
+           f, n, err := readVarint(b)
+           if err != nil {
+               return fmt.Errorf("failed to read varint field: %w", err)
+           }
+           b = b[n:]
+           target := reflect.ValueOf(v).Elem().Field(st.structFieldNum)
+           switch target.Interface().(type) {
+           case int64, int32, int16, int8, int:
+               target.SetInt(int64(f))
+           case uint64, uint32, uint16, uint8, uint:
+               target.SetUint(f)
+           case bool:
+               target.SetBool(f&1 == 1)
+           default:
+               return fmt.Errorf("unsupported type of varint: %s", target.Type().String())
+           }
+           b = b[n:]
+       default:
+           return fmt.Errorf("unsupported type: %d, err: %w", tp, UnknownType)
+       }
    }
    return nil
}

これでvarintは読めるようになった

varintだけならprotobufがよめるはずなので https://github.com/golang/protobuf を使って適当に吐いたバイナリを読み取るテストしてみる

proto定義は以下

syntax = "proto3";
package ex;

message TestVarint {
    int32 Int32 = 1;
    int64 Int64 = 2;
    bool  Boolean = 3;
}

Goの生成コードは以下

func main() {
    t := &ex.TestVarint{
        Int32:   12345,
        Int64:   67890,
        Boolean: true,
    }
    b, _ := proto.Marshal(t)
    fmt.Printf("%x\n", b)
}

出力は以下

$ go run main.go
08b96010b292041801

このバイナリを使ったテストをかく

func TestUnmarshal(t *testing.T) {
    varintTestBin, _ := hex.DecodeString("08b96010b292041801")
    type varintTest struct {
        Int32   int32 `protowire:"1,0"`
        Int64   int64 `protowire:"2,0"`
        Boolean bool  `protowire:"3,0"`
    }

    type args struct {
        b []byte
        v interface{}
    }
    tests := []struct {
        name    string
        args    args
        want    *varintTest
        wantErr bool
    }{
        {
            name: "Varintの検証バイナリ",
            args: args{
                b: varintTestBin,
                v: &varintTest{},
            },
            want: &varintTest{
                Int32:   12345,
                Int64:   67890,
                Boolean: true,
            },
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            if err := Unmarshal(tt.args.b, tt.args.v); (err != nil) != tt.wantErr {
                t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
            }
            if !reflect.DeepEqual(tt.args.v, tt.want) {
                t.Errorf("Unmarshal() got = %v, want %v", tt.args.b, tt.want)
            }
        })
    }
}

実行!

$ go test -v ./...

=== RUN   TestUnmarshal
=== RUN   TestUnmarshal/Varintの検証バイナリ
--- PASS: TestUnmarshal (0.00s)
    --- PASS: TestUnmarshal/Varintの検証バイナリ (0.00s)

うごいた!わーい

Varintのなかの曲者、sint32とsint64

こいつらは曲者で、zigzag encodingされて値が飛んできます

zigzag encodingについては前にここに書いたので詳細は割愛

convto.hatenablog.com

ようするに負数が登場しうる値ではsint32とsint64を使うとバイト数の削減になるのでつかってくれよな!ということでwireではsint32/sint64が指定されるとzigzag encodingで値が飛んでくるようになっている

これはzigzagだったらそのように扱うだけなので、すぐできる

type structTag struct {
    tp             uint8
    structFieldNum int
+   zigzag         bool
}
func parseTags(v interface{}) (map[uint32]structTag, error) {
    if reflect.ValueOf(v).Kind() != reflect.Ptr {
        return nil, errors.New("struct must be a pointer")
    }
    rt := reflect.Indirect(reflect.ValueOf(v)).Type()
    fieldSize := rt.NumField()
    tags := make(map[uint32]structTag, fieldSize)
    for i := 0; i < fieldSize; i++ {
        〜略〜
-      tags[uint32(fieldNum)] = structTag{tp: uint8(tp), structFieldNum: i}
+       zigzag := false
+       if len(t) == 3 && t[2] == "zigzag" {
+           zigzag = true
+       }
+       tags[uint32(fieldNum)] = structTag{tp: uint8(tp), structFieldNum: i, zigzag: zigzag}
    }
    return tags, nil
}
func Unmarshal(b []byte, v interface{}) error {
    sts, err := parseTags(v)
    if err != nil {
        return fmt.Errorf("failed to parse structTag from input interface{}: %w", err)
    }

    l := len(b)
    for l > 0 {
        〜略〜
        switch tp {
        case 0:
            f, n, err := readVarint(b)
            if err != nil {
                return fmt.Errorf("failed to read varint field: %w", err)
            }
            b = b[n:]
            target := reflect.ValueOf(v).Elem().Field(st.structFieldNum)
            switch target.Interface().(type) {
-          case int64, int32, int16, int8, int:
-              target.SetInt(int64(f))
+           case int64:
+               i := int64(f)
+               if st.zigzag {
+                   i = int64((uint64(i) >> 1) ^ uint64(((i&1)<<63)>>63))
+               }
+               target.SetInt(i)
+           case int32:
+               i := int32(f)
+               if st.zigzag {
+                   i = int32((uint32(i) >> 1) ^ uint32(((i&1)<<31)>>31))
+               }
+               target.SetInt(int64(i))
+           case int16, int8, int:
                target.SetInt(int64(f))
            〜略〜
        }
    }
    return nil
}

ではさきほどのようにバイナリを吐いてみる

protoに以下を追加してバイナリを吐きなおす

message TestVarintZigzag {
    sint32 Sint32 = 1;
    sint64 Sint64 = 2;
}
func main() {
    t := &ex.TestVarintZigzag{
        Sint32:   -12345,
        Sint64:   -67890,
    }
    b, _ := proto.Marshal(t)
    fmt.Printf("%x\n", b)
}
$ go run main.go              
08f1c00110e3a408

テストケース追加

func TestUnmarshal(t *testing.T) {
    〜略〜
+   testVarintZigzagBin, _ := hex.DecodeString("08f1c00110e3a408")
+   type testVarintZigzag struct {
+       Sint32 int32 `protowire:"1,0,zigzag"`
+       Sint64 int64 `protowire:"2,0,zigzag"`
+   }
    〜略〜
    }{
        〜略〜
+       {
+           name: "Varintでzigzagの検証バイナリ",
+           args: args{
+               b: testVarintZigzagBin,
+               v: &testVarintZigzag{},
+           },
+           want: &testVarintZigzag{
+               Sint32: -12345,
+               Sint64: -67890,
+           },
+       },

テスト実行!

$ go test -v ./...
=== RUN   TestUnmarshal
=== RUN   TestUnmarshal/Varintの検証バイナリ
=== RUN   TestUnmarshal/Varintでzigzagの検証バイナリ
--- PASS: TestUnmarshal (0.00s)
    --- PASS: TestUnmarshal/Varintの検証バイナリ (0.00s)
    --- PASS: TestUnmarshal/Varintでzigzagの検証バイナリ (0.00s)
PASS
ok      github.com/convto/protowire     0.878s

ヨシ!

Length-delimited

Length-delimitedの場合はvalueの先頭にバイト長が記載されているので、それをよみこめばよい。
そしてそのバイト長はいつもの readVarint でとれます。わーい楽

unmarshalに以下を追加

func Unmarshal(b []byte, v interface{}) error {
    〜略〜
        switch tp {
        case 0:
            〜略〜
+       case 2:
+           byteLen, n, err := readVarint(b)
+           if err != nil {
+               return fmt.Errorf("failed to read varint field: %w", err)
+           }
+           b = b[n:]
+           val := b[:byteLen]
+           b = b[int(byteLen):]
+           target := reflect.ValueOf(v).Elem().Field(st.structFieldNum)
+           switch target.Interface().(type) {
+           case string:
+               target.SetString(string(val))
+           case []byte:
+               target.SetBytes(val)
+           default:
+               return fmt.Errorf("unsupported type of length-delimited: %s", target.Type().String())
+           }
+           
        default:
            return fmt.Errorf("unsupported type: %d, err: %w", tp, UnknownType)
        }
    }
    return nil
}

以上。テストするぞ。いつもの流れでprotoいじってバイナリ吐いてみるぞ

message TestLengthDelimited {
    string Str   = 1;
    bytes  Bytes = 2;
}
func main() {
    t := &ex.TestLengthDelimited{
        Str:   "これはてすとだよ",
        Bytes: []byte{0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA},
    }
    b, _ := proto.Marshal(t)
    fmt.Printf("%x\n", b)
}
$ go run main.go
0a18e38193e3828ce381afe381a6e38199e381a8e381a0e382881206ffeeddccbbaa
func TestUnmarshal(t *testing.T) {
    〜略〜
+   testLengthDelimitedBin, _ := hex.DecodeString("0a18e38193e3828ce381afe381a6e38199e381a8e381a0e382881206ffeeddccbbaa")
+   type testLengthDelimited struct {
+       Str   string `protowire:"1,2"`
+       Bytes []byte `protowire:"2,2"`
+   }
    〜略〜
    }{
        〜略〜
+       {
+           name: "Length-delimitedの検証バイナリ",
+           args: args{
+               b: testLengthDelimitedBin,
+               v: &testLengthDelimited{},
+           },
+           want: &testLengthDelimited{
+               Str:   "これはてすとだよ",
+               Bytes: []byte{0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0xAA},
+           },
+       },
    }
    〜略〜
}

いざテスト!

$ go test -v ./...
=== RUN   TestUnmarshal
=== RUN   TestUnmarshal/Varintの検証バイナリ
=== RUN   TestUnmarshal/Varintでzigzagの検証バイナリ
=== RUN   TestUnmarshal/Length-delimitedの検証バイナリ
--- PASS: TestUnmarshal (0.00s)
    --- PASS: TestUnmarshal/Varintの検証バイナリ (0.00s)
    --- PASS: TestUnmarshal/Varintでzigzagの検証バイナリ (0.00s)
    --- PASS: TestUnmarshal/Length-delimitedの検証バイナリ (0.00s)
PASS
ok      github.com/convto/protowire     0.0984s

やったね!

まとめ

  • wireはtagとvalueにわかれている
  • tagは下位3bitがtype, それ以外は可変長のfield_number
  • valueはtypeによってパースのやりかたがことなる
  • 仕様はまだあって、今の実装だとtypeが足りなかったりrepeatedとかの packed repeated field 対応とかが必要

感想

wireはencodingとして優れていると感じた。
byte効率もすぐれているし、client - server でネゴらなくてもスキーマ変えれる余地があったり、今回は触れてないけどある値があとからrepeatableになっても互換性を維持できる。

あとprotobufのMarshal/Unmarshalする構造だけproto定義から生成する(without gRPC server)みたいなこともできて、jsonでネゴるより楽な気がした。

protobufはgRPCとセットの文脈で語られることが多いけど、普通に HTTP over Protobuf とかやってみるとめちゃめちゃ嬉しいんじゃないかな。protobufだけでも有用。

あとは意外と普通にprotobufのunmarshalerはかけそうだった!
まだまだ仕様の内容は実装しきれてないので暇なとき実装すすめようかな

コード生成もサポートできると(実際使うときは推奨されてる実装を使うのでとくにやくにはたたないけど)気持ち的に嬉しいのでそれもやってみたい

バイナリのUnmarshalみたいなの自体あまりやったことないので、仕様見るのも必要な前提知識がなくてしらべるところから!とかで大変だったけどwireについては8割がた理解したので良かった。