はじめに
これは RSA完全理解 Advent Calendar 2021 の16日目の記事です。
前回は多倍長整数の四則演算をやりました。乗法は だったので、少し効率の良いアルゴリズムを紹介します
karatsuba法
桁数nの整数 があったとします。
これらをそれぞれ ずつに分割して
- の上位の桁のほうを
- の下位の桁のほうを
- の上位の桁のほうを
- の下位の桁のほうを
とします。 そうすると は
と表せます。
このとき計算コストがたかい乗法が4回必要ですが、ここで について以下のように整理すると乗法の回数が3回で済みます。このときの操作自体は計算結果が となるような式を当てはめただけです。計算すると成り立つことがわかります
さらにこれを通常の乗法でもとめる必要のあると説明した の式にあてはめて考えると
この式の各項をよく見ると
- は乗算で
- は乗算で
-
- の乗算はやる必要あり
- はただの減算で
- はただの減算で
- はもともと使うので すでに計算済み なのでここでわざわざ計算し直さなくて良い
- の乗算はやる必要あり
という計算の組み合わせで求められることがわかります。
本来4回必要だった乗算が3回になるので計算量下がりそうな雰囲気しますね!
これを再帰的に適用すると最終的な計算量は となり大体 くらいになる。もともとより早くなったねやったね
計算量についてくわしく
先程の最後に得られた式からそれぞれの計算回数を整理すると
- 桁の掛け算3回
- 桁の加減算4回
- ででてる加算2回はn桁2回ぶん
- と は 桁の減算だから合わせてn桁1回ぶん
- はかぶってるのが 桁だけなので実質1/2回
- もかぶってるのが 桁だけなので実質1/2回
かけ合わせたい数が2べきかどうかとかにもよるけど、筆算の乗算とくらべて定数倍が多そうなイメージが持てますね
実装
これから実装のポイントとかをかいつまんで説明しますが、完全な動作する実装が見たい方はこのcommitをみてください
まずintを超える長い桁数の初期化がだるいので、string食わせられるようにします
// SetString は入力を10進数としてscanします // 予期しない入力によって読み取りに失敗するとpanicします func (b *Int) SetString(s string) *Int { neg := false switch s[0] { case '-': neg = true s = s[1:] case '+': s = s[1:] } abs := make(digits, len(s)) for i, r := range s { d, err := strconv.ParseUint(string(r), 10, 8) if err != nil { panic(err) } abs[i] = uint8(d) } b.abs = abs b.neg = neg return b }
んで次に乗算の処理で40桁より大きければkaratsuba使うようにします(定数倍が大きいため。本来はベンチとかとってどのあたりから逆転するのかとかみて設定したほうがいいけど、まあ学習向けだしそこらへんはどんぶり勘定)
// mul は |x| * |y| の絶対値による乗算を行う func mul(x, y digits) digits { m, n := len(x), len(y) switch { case m < n: return norm(mul(y, x)) case m == 0 || n == 0: return digits{} } if m < karatsubaThreshold && n < karatsubaThreshold { return norm(basicMul(x, y)) } // karatsubaThreshold までが2のべき乗となるようにpaddingをとる k := karatsubaLen(m, karatsubaThreshold) px := leftPad(x, k-len(x)) py := leftPad(y, k-len(y)) return norm(karatsuba(px, py)) } // karatsubaLen はnが閾値まで2べきならnをそのまま返し、そうでなければ閾値まで2べきとなるようなn以上のできるだけ小さい数を返す func karatsubaLen(n, threshold int) int { var i uint = 0 for n > threshold { if n&1 == 1 { n++ } n >>= 1 i++ } n <<= i return n } // karatsuba法は定数倍が大きいので、40桁以上の乗算について適用させるようにする const karatsubaThreshold = 40 // leftPad は指定された数だけ上位の桁に0を追加します func leftPad(x digits, n int) digits { m := len(x) l := m + n abs := make(digits, l) for i := 0; i < l; i++ { if i < n { abs[i] = 0 } else { abs[i] = x[i-n] } } return abs } // rightPad は指定された数だけ下位の桁に0を追加します func rightPad(x digits, n int) digits { m := len(x) l := m + n abs := make(digits, l) for i := 0; i < l; i++ { if i < m { abs[i] = x[i] } else { abs[i] = 0 } } return abs }
karatsubaLen()
は今回実装するkaratsubaは40桁より小さくなったら通常のlong multiplicationで乗算するので、うまいことその範囲が2べきになるようにしてる
2秒くらいで思いついた処理で下1bit立ってたら2で割れないので1追加する!というのをずっとやるだけ。雑に書いた割にそこそこ効率良くて気に入ってる
leftPad()
/ rightPad()
は右か左に指定量の0paddingつめて返すだけ
つぎはkaratsuba本体
// karatsuba は karatsuba's algorithm で乗算を行う // 呼び出し側は len(x) == len(y) かつ karatsubaThreshold まで2のべき乗のサイズになっていることを保証すること func karatsuba(x, y digits) digits { m := len(x) // len(x) について、奇数/閾値以下/0のいずれかなら通常の乗算にて計算する if m&1 != 0 || m <= karatsubaThreshold || m < 2 { return basicMul(x, y) } m2 := m >> 1 x1, x0 := x[m2:], x[0:m2] y1, y0 := y[m2:], y[0:m2] x0y0 := karatsuba(x0, y0) x1y1 := karatsuba(x1, y1) // x0y1 + x1y0 = (x0-x1)(y1-y0) + (x0y0 + x1y1) となるのでその計算 // (x0+x1)(y1+y0) - (x0y0 + x1y1) の形にも整理できるが、 // 加算はcarryが発生する可能性があり後続の再帰処理にて2のべき乗のサイズとならない可能性があるため減算の形で扱っている s := 1 xd := make(digits, m2) if cmp(x0, x1) >= 0 { xd = basicSub(x0, x1) } else { s = -s xd = basicSub(x1, x0) } yd := make(digits, m2) if cmp(y1, y0) >= 0 { yd = basicSub(y1, y0) } else { s = -s yd = basicSub(y0, y1) } var p digits if s < 0 { p = basicSub(basicAdd(x0y0, x1y1), karatsuba(xd, yd)) } else { p = basicAdd(karatsuba(xd, yd), basicAdd(x0y0, x1y1)) } // x1y1*(10^m) + p*(10^m2) + x0y0 x0y0 = rightPad(x0y0, m) p = rightPad(p, m2) return basicAdd(basicAdd(x0y0, x1y1), p) }
if m&1 != 0 || m <= karatsubaThreshold || m < 2 {
のところは少し厚く判定しすぎてて、今回の実装なら多分karatsubaThresholdとの比較だけで十分なはず。安全に倒して 標準パッケージの実装に寄せた だけ
注意点としては と を計算するときに負数になる可能性があるところ。どちらか一方が負数だと はマイナスになるのでその後足し合わせるときに減算として扱っていたりする
allocはマジで適当にやってて再帰のたびに新規割当するので、このあとベンチ取るんですがなかなかエグい感じになってます
比較
ベンチマークのコードをサクッと書く
var ( x1000 = new(Int).SetString("9876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210").abs pad1000 = leftPad(x1000, karatsubaLen(len(x1000), karatsubaThreshold)-len(x1000)) x300 = new(Int).SetString("987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210").abs pad300 = leftPad(x300, karatsubaLen(len(x300), karatsubaThreshold)-len(x300)) x200 = new(Int).SetString("98765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210").abs pad200 = leftPad(x200, karatsubaLen(len(x200), karatsubaThreshold)-len(x200)) x100 = new(Int).SetString("9876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210").abs pad100 = leftPad(x100, karatsubaLen(len(x100), karatsubaThreshold)-len(x100)) ) func BenchmarkKaratsuba_len1000(b *testing.B) { for i := 0; i < b.N; i++ { karatsuba(pad1000, pad1000) } } func BenchmarkBasicMul_len1000(b *testing.B) { for i := 0; i < b.N; i++ { basicMul(x1000, x1000) } } func BenchmarkKaratsuba_len300(b *testing.B) { for i := 0; i < b.N; i++ { karatsuba(pad300, pad300) } } func BenchmarkBasicMul_len300(b *testing.B) { for i := 0; i < b.N; i++ { basicMul(x300, x300) } } func BenchmarkKaratsuba_len200(b *testing.B) { for i := 0; i < b.N; i++ { karatsuba(pad200, pad200) } } func BenchmarkBasicMul_len200(b *testing.B) { for i := 0; i < b.N; i++ { basicMul(x200, x200) } } func BenchmarkKaratsuba_len100(b *testing.B) { for i := 0; i < b.N; i++ { karatsuba(pad100, pad100) } } func BenchmarkBasicMul_len100(b *testing.B) { for i := 0; i < b.N; i++ { basicMul(x100, x100) } }
実行!
$ go test -bench . -benchmem goos: darwin goarch: amd64 pkg: github.com/convto/mycrypto/big cpu: Intel(R) Core(TM) i5-8257U CPU @ 1.40GHz BenchmarkKaratsuba_len1000-8 861 1332737 ns/op 187013 B/op 1453 allocs/op BenchmarkBasicMul_len1000-8 397 2991878 ns/op 2048 B/op 1 allocs/op BenchmarkKaratsuba_len300-8 5594 193421 ns/op 20000 B/op 157 allocs/op BenchmarkBasicMul_len300-8 4473 269398 ns/op 640 B/op 1 allocs/op BenchmarkKaratsuba_len200-8 21901 54226 ns/op 14000 B/op 157 allocs/op BenchmarkBasicMul_len200-8 10000 119414 ns/op 416 B/op 1 allocs/op BenchmarkKaratsuba_len100-8 59186 19514 ns/op 3856 B/op 49 allocs/op BenchmarkBasicMul_len100-8 40922 29108 ns/op 208 B/op 1 allocs/op PASS ok github.com/convto/mycrypto/big 11.665s
ちゃんといい感じに早くなっている!
メモリ割当がめちゃめちゃ適当で再帰するたびに割り当てられるので、alloc回数と割当量がやばいことになってるけどまあ今回は無視する。余裕があれば直すかも
ちなみにgoの標準パッケージはメモリ割当もふくめちゃんとkaratsubaしてて、しかも メモリレイアウトとかも丁寧にコメントしてる のでめちゃめちゃ参考になったりする
これにて(メモリ割当以外は)そこそこ実用的な多倍長整数の乗算が実装できたぞ!