ちりもつもればミルキーウェイ

好奇心に可処分時間が奪われる

protoc-gen-gogoの成果物がかなり読みやすい件

3行まとめ

protoc-gen-gogoの成果物みてみたら、
ベタッとbytesをunmarshalするコードが書いてあって、
仕様と一緒に見たらかなり読みやすかった

きっかけ

ectdのraftコードを読んでたらたまたまprotoc-gen-gogoで生成されたコードを発見。

github.com

ちょうどprotobufの仕様とか読んだりしてたから気になってみてみた。
素朴にbytesを読んでるだけで抽象化とか特にしてないから前提知識なくても読めて、読みやすかった。

(以前 Goの実装 を読もうとしたこともあるけど、jsonやtextやwireにうまいこと変換できるような抽象化がなされていて、wireの実装だけ詳細が知りたかったのでちょっと困った)

生成

こういうproto定義を

syntax = "proto3";
package example;

message Example {
    int32 standardInt32 = 1;
    sint32 signedInt32 = 2;
    string str = 3;
}

こうする

$ protoc --gofast_out=. ex.proto

生成物の確認

かなり素朴でわかりやすい。 基本的にbyteをゴニョゴニョしてるだけなので仕様書と照らし合わせれば読める

構造体

type Example struct {
    StandardInt32        int32    `protobuf:"varint,1,opt,name=standardInt32,proto3" json:"standardInt32,omitempty"`
    SignedInt32          int32    `protobuf:"zigzag32,2,opt,name=signedInt32,proto3" json:"signedInt32,omitempty"`
    Str                  string   `protobuf:"bytes,3,opt,name=str,proto3" json:"str,omitempty"`
    XXX_NoUnkeyedLiteral struct{} `json:"-"`
    XXX_unrecognized     []byte   `json:"-"`
    XXX_sizecache        int32    `json:"-"`
}
  • wire用のタグがふられる
  • signed int はzigzag encoding使うから区別のためにzigzagってタグがふってあるぽい
  • optionalとかもふってある

marshal

func (m *Example) MarshalToSizedBuffer(dAtA []byte) (int, error) {
    i := len(dAtA)
    _ = i
    var l int
    _ = l
    if m.XXX_unrecognized != nil {
        i -= len(m.XXX_unrecognized)
        copy(dAtA[i:], m.XXX_unrecognized)
    }
    if len(m.Str) > 0 {
        i -= len(m.Str)
        copy(dAtA[i:], m.Str)
        i = encodeVarintExample(dAtA, i, uint64(len(m.Str)))
        i--
        dAtA[i] = 0x1a
    }
    if m.SignedInt32 != 0 {
        i = encodeVarintExample(dAtA, i, uint64((uint32(m.SignedInt32)<<1)^uint32((m.SignedInt32>>31))))
        i--
        dAtA[i] = 0x10
    }
    if m.StandardInt32 != 0 {
        i = encodeVarintExample(dAtA, i, uint64(m.StandardInt32))
        i--
        dAtA[i] = 0x8
    }
    return len(dAtA) - i, nil
}

func encodeVarintExample(dAtA []byte, offset int, v uint64) int {
    offset -= sovExample(v)
    base := offset
    for v >= 1<<7 {
        dAtA[offset] = uint8(v&0x7f | 0x80)
        v >>= 7
        offset++
    }
    dAtA[offset] = uint8(v)
    return base
}

func sovExample(x uint64) (n int) {
    return (math_bits.Len64(x|1) + 6) / 7
}
  • MarshalToSizedBuffer は渡されたbyte sliceにmarshalした結果をつめこむ。末尾からencodeしていく
  • 各フィールドのmarshalは encodeVarintExample でやってる
    • vが8バイト以上の値だったら 最上位byteが1と、vの下位7バイトを連結してdestに詰めてvを7bit右シフトする(1byte読みすすめる)
      • varintのwire typeは上位1bitが立ってるとデータが連続することを表す
    • vが7bit以下の値だったらその値を書き込んで終了
      • varintのwire typeは上位1bitが立っていないとデータ終端であることを表す
  • sovExample はxに必要なバイト数をわりだす
    • xの下位1bitを立てる(立ってないと7で割ったときに端数になるから?よくわからん)
    • 6bit足す(タグをつけるので必要。field number + wire type)
    • 7で割る(必要なbyte数を取得)

unmarshal

func (m *Example) Unmarshal(dAtA []byte) error {
    l := len(dAtA)
    iNdEx := 0
    for iNdEx < l {
        preIndex := iNdEx
        var wire uint64
        for shift := uint(0); ; shift += 7 {
            if shift >= 64 {
                return ErrIntOverflowExample
            }
            if iNdEx >= l {
                return io.ErrUnexpectedEOF
            }
            b := dAtA[iNdEx]
            iNdEx++
            wire |= uint64(b&0x7F) << shift
            if b < 0x80 {
                break
            }
        }
        fieldNum := int32(wire >> 3)
        wireType := int(wire & 0x7)
        if wireType == 4 {
            return fmt.Errorf("proto: Example: wiretype end group for non-group")
        }
        if fieldNum <= 0 {
            return fmt.Errorf("proto: Example: illegal tag %d (wire type %d)", fieldNum, wire)
        }
        switch fieldNum {
        case 1:
            if wireType != 0 {
                return fmt.Errorf("proto: wrong wireType = %d for field StandardInt32", wireType)
            }
            m.StandardInt32 = 0
            for shift := uint(0); ; shift += 7 {
                if shift >= 64 {
                    return ErrIntOverflowExample
                }
                if iNdEx >= l {
                    return io.ErrUnexpectedEOF
                }
                b := dAtA[iNdEx]
                iNdEx++
                m.StandardInt32 |= int32(b&0x7F) << shift
                if b < 0x80 {
                    break
                }
            }
        case 2:
            if wireType != 0 {
                return fmt.Errorf("proto: wrong wireType = %d for field SignedInt32", wireType)
            }
            var v int32
            for shift := uint(0); ; shift += 7 {
                if shift >= 64 {
                    return ErrIntOverflowExample
                }
                if iNdEx >= l {
                    return io.ErrUnexpectedEOF
                }
                b := dAtA[iNdEx]
                iNdEx++
                v |= int32(b&0x7F) << shift
                if b < 0x80 {
                    break
                }
            }
            v = int32((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31))
            m.SignedInt32 = v
        case 3:
            if wireType != 2 {
                return fmt.Errorf("proto: wrong wireType = %d for field Str", wireType)
            }
            var stringLen uint64
            for shift := uint(0); ; shift += 7 {
                if shift >= 64 {
                    return ErrIntOverflowExample
                }
                if iNdEx >= l {
                    return io.ErrUnexpectedEOF
                }
                b := dAtA[iNdEx]
                iNdEx++
                stringLen |= uint64(b&0x7F) << shift
                if b < 0x80 {
                    break
                }
            }
            intStringLen := int(stringLen)
            if intStringLen < 0 {
                return ErrInvalidLengthExample
            }
            postIndex := iNdEx + intStringLen
            if postIndex < 0 {
                return ErrInvalidLengthExample
            }
            if postIndex > l {
                return io.ErrUnexpectedEOF
            }
            m.Str = string(dAtA[iNdEx:postIndex])
            iNdEx = postIndex
        default:
            iNdEx = preIndex
            skippy, err := skipExample(dAtA[iNdEx:])
            if err != nil {
                return err
            }
            if (skippy < 0) || (iNdEx+skippy) < 0 {
                return ErrInvalidLengthExample
            }
            if (iNdEx + skippy) > l {
                return io.ErrUnexpectedEOF
            }
            m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...)
            iNdEx += skippy
        }
    }

    if iNdEx > l {
        return io.ErrUnexpectedEOF
    }
    return nil
}
  • wireタグをよみこみ
    • あるバイトをよみとり下位7bitをwireにたてる
      • この実装だと上位1bitをどんな場合にも捨てるので 0b11111000(field_num: 15, wire_type: 0) みたいなときにfield_numを7だと思い込まないかな?
    • 終端バイトが登場せずに64bit超えたらoverflowでエラー
      • protobufのfield_numは29bitらしいので、wire typeを含めても32bitが最大な気がする。64bitとなっているのはなぜ?実装ミス?
    • index値がdataのsizeを超えてもエラー
    • field_numは上位1bitが立っていたら終端?field_numもvarintみたいな仕様なのかな。ちょっとわからん。あとでバイナリ見るか
    • あるバイトの上位1bitがたっていたら終端byteなのでwire変数にフィールド情報つめるのを終了
  • wireタグの下位3bitはtype、それ以外はfield_numとする
  • field_numのフィールドにwire_typeごとのよみとりかたで値をパースする
    • wire_typeがvarintなら、リトルエンディアンに並び替えて各byteの下位7bitをフィールドに詰めていく
      • 上位1bitは終端判断のbitなので、立ってたら処理を抜ける
      • このbit読み取り処理は完全にタグ読み取り処理と同一。コード生成側で使い回してるなさては
      • 本来32bitで済むけど、最大64bit型もサポートしてるから64としちゃって全部で使いまわしてるのか。なら納得
    • wire_typeがvarintでかつ符号付きだったら、↑で読み取った値をzigzag encodingに直してフィールドに詰める
      • int32((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31)) の意味
        • 左辺はvから下位1bitを捨てもの。下位1bitは正負を表現するのでこれは絶対値
        • 右辺は (v&1)<<31 のときにint32は2の補数表現なのでvが1だったら31bitシフトすると -2147483648(最上位bitのみがたったint32のマイナスの最大値)0 になる
        • ((v&1)<<31)>>31 のときに31bit右シフトされる。つまり -10 になる
          • int同士のシフト演算は算術シフトになるので最上位bitは変わらないままシフトされる
          • Goの算術シフトと論理シフトの使い分けがよくわからん。仕様みてみるか
        • それをuint32でキャストすると、2の補数表現での-1は符号なしだと最大値になるので、負数だと全bitがたってて、整数だと全bitが寝てる値が取れる
        • 右辺の絶対値と左辺の-1か1のbitをXORとると、2の補数表現に変換された値を取り出せる
    • wire_typeがlength delimitedなら、文字長よんでそのぶんのバイトをつめる
      • いつものリトルエンディアンにして下位7bitを詰めるやつで文字長を読む
      • 文字長がわかったら、いま読んでる位置から文字長だけ読む

これにてUnmarshal完了!

感想

  • めっちゃ読みやすい
  • 仕様読んだときのこういうときどうするんだろう?の疑問がいくつか解決しそう
  • やっぱり仕様/実装どっちも見ることで理解度は大きく向上する
  • まだちょっとわからないところがあるのでもうちょっと読み込みたい
  • zigzag encodingについてちょっとまとめたいなと思った。今の理解は以下なので実装とかして深めたい
    • standard intは2の補数表現で表されるから、負数になると絶対値が小さくても最上位bitが立って使用bitが常に最大になる
    • zigzag encodingは2の補数表現を使わずに下位1bitの有無で正負を判断するため、絶対値が小さいときは使用bitもすくない
    • その代わり正の値nと比較すると、 常に余計に1bitたつ。
    • uintとかならstandard intのほうがお得だよ
    • 2の補数表現と比較すると、そのまま加減算はできないけど値の表現としては絶対値がすくない負数のときはコスパがいいという。一長一短かな
  • bit演算わかりづれえ...uintとintで挙動が違うのがかなり混乱する
    • 2の補数表現を意識しながらbit演算をしなくてもいいよ、というサポートなんだろうか
    • 逆にわかりづらい。他の言語もこうなのかな。調べてみたい