3行まとめ
protoc-gen-gogoの成果物みてみたら、
ベタッとbytesをunmarshalするコードが書いてあって、
仕様と一緒に見たらかなり読みやすかった
きっかけ
ectdのraftコードを読んでたらたまたまprotoc-gen-gogoで生成されたコードを発見。
ちょうどprotobufの仕様とか読んだりしてたから気になってみてみた。
素朴にbytesを読んでるだけで抽象化とか特にしてないから前提知識なくても読めて、読みやすかった。
(以前 Goの実装 を読もうとしたこともあるけど、jsonやtextやwireにうまいこと変換できるような抽象化がなされていて、wireの実装だけ詳細が知りたかったのでちょっと困った)
生成
こういうproto定義を
syntax = "proto3"; package example; message Example { int32 standardInt32 = 1; sint32 signedInt32 = 2; string str = 3; }
こうする
$ protoc --gofast_out=. ex.proto
生成物の確認
かなり素朴でわかりやすい。 基本的にbyteをゴニョゴニョしてるだけなので仕様書と照らし合わせれば読める
構造体
type Example struct { StandardInt32 int32 `protobuf:"varint,1,opt,name=standardInt32,proto3" json:"standardInt32,omitempty"` SignedInt32 int32 `protobuf:"zigzag32,2,opt,name=signedInt32,proto3" json:"signedInt32,omitempty"` Str string `protobuf:"bytes,3,opt,name=str,proto3" json:"str,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` }
- wire用のタグがふられる
- signed int はzigzag encoding使うから区別のためにzigzagってタグがふってあるぽい
- optionalとかもふってある
marshal
func (m *Example) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if m.XXX_unrecognized != nil { i -= len(m.XXX_unrecognized) copy(dAtA[i:], m.XXX_unrecognized) } if len(m.Str) > 0 { i -= len(m.Str) copy(dAtA[i:], m.Str) i = encodeVarintExample(dAtA, i, uint64(len(m.Str))) i-- dAtA[i] = 0x1a } if m.SignedInt32 != 0 { i = encodeVarintExample(dAtA, i, uint64((uint32(m.SignedInt32)<<1)^uint32((m.SignedInt32>>31)))) i-- dAtA[i] = 0x10 } if m.StandardInt32 != 0 { i = encodeVarintExample(dAtA, i, uint64(m.StandardInt32)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } func encodeVarintExample(dAtA []byte, offset int, v uint64) int { offset -= sovExample(v) base := offset for v >= 1<<7 { dAtA[offset] = uint8(v&0x7f | 0x80) v >>= 7 offset++ } dAtA[offset] = uint8(v) return base } func sovExample(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 }
MarshalToSizedBuffer
は渡されたbyte sliceにmarshalした結果をつめこむ。末尾からencodeしていく- 各フィールドのmarshalは
encodeVarintExample
でやってる sovExample
はxに必要なバイト数をわりだす- xの下位1bitを立てる(立ってないと7で割ったときに端数になるから?よくわからん)
- 6bit足す(タグをつけるので必要。field number + wire type)
- 7で割る(必要なbyte数を取得)
unmarshal
func (m *Example) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowExample } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: Example: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: Example: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field StandardInt32", wireType) } m.StandardInt32 = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowExample } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.StandardInt32 |= int32(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field SignedInt32", wireType) } var v int32 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowExample } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ v |= int32(b&0x7F) << shift if b < 0x80 { break } } v = int32((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31)) m.SignedInt32 = v case 3: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Str", wireType) } var stringLen uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowExample } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ stringLen |= uint64(b&0x7F) << shift if b < 0x80 { break } } intStringLen := int(stringLen) if intStringLen < 0 { return ErrInvalidLengthExample } postIndex := iNdEx + intStringLen if postIndex < 0 { return ErrInvalidLengthExample } if postIndex > l { return io.ErrUnexpectedEOF } m.Str = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipExample(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthExample } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil }
- wireタグをよみこみ
- あるバイトをよみとり下位7bitをwireにたてる
- この実装だと上位1bitをどんな場合にも捨てるので
0b11111000(field_num: 15, wire_type: 0)
みたいなときにfield_numを7だと思い込まないかな?
- この実装だと上位1bitをどんな場合にも捨てるので
- 終端バイトが登場せずに64bit超えたらoverflowでエラー
- protobufのfield_numは29bitらしいので、wire typeを含めても32bitが最大な気がする。64bitとなっているのはなぜ?実装ミス?
- index値がdataのsizeを超えてもエラー
- field_numは上位1bitが立っていたら終端?field_numもvarintみたいな仕様なのかな。ちょっとわからん。あとでバイナリ見るか
- あるバイトの上位1bitがたっていたら終端byteなのでwire変数にフィールド情報つめるのを終了
- あるバイトをよみとり下位7bitをwireにたてる
- wireタグの下位3bitはtype、それ以外はfield_numとする
- field_numのフィールドにwire_typeごとのよみとりかたで値をパースする
- wire_typeがvarintなら、リトルエンディアンに並び替えて各byteの下位7bitをフィールドに詰めていく
- 上位1bitは終端判断のbitなので、立ってたら処理を抜ける
- このbit読み取り処理は完全にタグ読み取り処理と同一。コード生成側で使い回してるなさては
- 本来32bitで済むけど、最大64bit型もサポートしてるから64としちゃって全部で使いまわしてるのか。なら納得
- wire_typeがvarintでかつ符号付きだったら、↑で読み取った値をzigzag encodingに直してフィールドに詰める
int32((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31))
の意味- 左辺はvから下位1bitを捨てもの。下位1bitは正負を表現するのでこれは絶対値
- 右辺は
(v&1)<<31
のときにint32は2の補数表現なのでvが1だったら31bitシフトすると-2147483648(最上位bitのみがたったint32のマイナスの最大値)
か0
になる ((v&1)<<31)>>31
のときに31bit右シフトされる。つまり-1
か0
になる- int同士のシフト演算は算術シフトになるので最上位bitは変わらないままシフトされる
- Goの算術シフトと論理シフトの使い分けがよくわからん。仕様みてみるか
- それをuint32でキャストすると、2の補数表現での-1は符号なしだと最大値になるので、負数だと全bitがたってて、整数だと全bitが寝てる値が取れる
- 右辺の絶対値と左辺の-1か1のbitをXORとると、2の補数表現に変換された値を取り出せる
- wire_typeがlength delimitedなら、文字長よんでそのぶんのバイトをつめる
- いつものリトルエンディアンにして下位7bitを詰めるやつで文字長を読む
- 文字長がわかったら、いま読んでる位置から文字長だけ読む
- wire_typeがvarintなら、リトルエンディアンに並び替えて各byteの下位7bitをフィールドに詰めていく
これにてUnmarshal完了!
感想
- めっちゃ読みやすい
- 仕様読んだときのこういうときどうするんだろう?の疑問がいくつか解決しそう
- やっぱり仕様/実装どっちも見ることで理解度は大きく向上する
- まだちょっとわからないところがあるのでもうちょっと読み込みたい
- zigzag encodingについてちょっとまとめたいなと思った。今の理解は以下なので実装とかして深めたい
- standard intは2の補数表現で表されるから、負数になると絶対値が小さくても最上位bitが立って使用bitが常に最大になる
- zigzag encodingは2の補数表現を使わずに下位1bitの有無で正負を判断するため、絶対値が小さいときは使用bitもすくない
- その代わり正の値nと比較すると、 常に余計に1bitたつ。
- uintとかならstandard intのほうがお得だよ
- 2の補数表現と比較すると、そのまま加減算はできないけど値の表現としては絶対値がすくない負数のときはコスパがいいという。一長一短かな
- bit演算わかりづれえ...uintとintで挙動が違うのがかなり混乱する
- 2の補数表現を意識しながらbit演算をしなくてもいいよ、というサポートなんだろうか
- 逆にわかりづらい。他の言語もこうなのかな。調べてみたい