ちりもつもればミルキーウェイ

好奇心に可処分時間が奪われる

karatsuba法

はじめに

これは RSA完全理解 Advent Calendar 2021 の16日目の記事です。

前回は多倍長整数の四則演算をやりました。乗法は \mathcal{O}(n^2) だったので、少し効率の良いアルゴリズムを紹介します

karatsuba法

桁数nの整数 x, y があったとします。
これらをそれぞれ \dfrac n 2 ずつに分割して

  • x の上位の桁のほうを x_0
  • x の下位の桁のほうを x_1
  • x の上位の桁のほうを x_0
  • x の下位の桁のほうを x_1

とします。 そうすると x\times y

(x_0 + x_1)(y_0 + y_1) = x_0y_0 \times 10^n + (x_0y_1 + x_1y_0)\times 10^ {\frac n 2} + x_1y_1 \ \dots i

と表せます。

このとき計算コストがたかい乗法が4回必要ですが、ここで x_0y_1 + x_1y_0 について以下のように整理すると乗法の回数が3回で済みます。このときの操作自体は計算結果が x_0y_1 + x_1y_0 となるような式を当てはめただけです。計算すると成り立つことがわかります

x_0y_1 + x_1y_0 = (x_0-x_1)(y_1-y_0) + (x_0y_0 + x_1y_1)

さらにこれを通常の乗法でもとめる必要のあると説明した i の式にあてはめて考えると

\displaystyle(x_0 + x_1)(y_0 + y_1) \\ = x_0y_0 \times 10^n + (x_0y_1 + x_1y_0)\times 10^ {\frac n 2} + x_1y_1 \\ = x_0y_0 \times 10^n  + x_1y_1 + ( (x_0 - x_1)(y_1-y_0) + (x_0y_0 + x_1y_1) )\times 10^ {\frac n 2}

この式の各項をよく見ると

  • x0y0 は乗算で \mathcal{O}( (\dfrac n 2)^2)
  • x1y1 は乗算で \mathcal{O}( (\dfrac n 2)^2)
  • ( (x_0 - x_1)(y_1-y_0) + (x_0y_0 + x_1y_1) )
    • (x_0 - x_1)(y_1-y_0) の乗算はやる必要あり \mathcal{O}( (\dfrac n 2)^2)
      • x0-x1 はただの減算で \mathcal{O}(n)
      • y1-y0 はただの減算で \mathcal{O}(n)
    • (x_0y_0 + x_1y_1) はもともと使うので すでに計算済み なのでここでわざわざ計算し直さなくて良い

という計算の組み合わせで求められることがわかります。

本来4回必要だった乗算が3回になるので計算量下がりそうな雰囲気しますね!
これを再帰的に適用すると最終的な計算量は \mathcal{O}(n^{log_2 3}) となり大体 \mathcal{O}(n^{1.58}) くらいになる。もともとより早くなったねやったね

計算量についてくわしく

先程の最後に得られた式からそれぞれの計算回数を整理すると

  • \dfrac n 2 桁の掛け算3回
  • n 桁の加減算4回
    • ( (x_0 - x_1)(y_1-y_0) + (x_0y_0 + x_1y_1) ) ででてる加算2回はn桁2回ぶん
    • x_0 - x_1y_1-y_0\dfrac n 2 桁の減算だから合わせてn桁1回ぶん
    • x_0y_0 \times 10^n + ( (x_0 - x_1)(y_1-y_0) + (x_0y_0 + x_1y_1) )\times 10^ {\frac n 2} はかぶってるのが \dfrac n 2 桁だけなので実質1/2回
    • x_1y_1 + ( (x_0 - x_1)(y_1-y_0) + (x_0y_0 + x_1y_1) )\times 10^ {\frac n 2} もかぶってるのが \dfrac n 2 桁だけなので実質1/2回

かけ合わせたい数が2べきかどうかとかにもよるけど、筆算の乗算とくらべて定数倍が多そうなイメージが持てますね

実装

これから実装のポイントとかをかいつまんで説明しますが、完全な動作する実装が見たい方はこのcommitをみてください

github.com

まずintを超える長い桁数の初期化がだるいので、string食わせられるようにします

// SetString は入力を10進数としてscanします
// 予期しない入力によって読み取りに失敗するとpanicします
func (b *Int) SetString(s string) *Int {
    neg := false
    switch s[0] {
    case '-':
        neg = true
        s = s[1:]
    case '+':
        s = s[1:]
    }
    abs := make(digits, len(s))
    for i, r := range s {
        d, err := strconv.ParseUint(string(r), 10, 8)
        if err != nil {
            panic(err)
        }
        abs[i] = uint8(d)
    }
    b.abs = abs
    b.neg = neg
    return b
}

んで次に乗算の処理で40桁より大きければkaratsuba使うようにします(定数倍が大きいため。本来はベンチとかとってどのあたりから逆転するのかとかみて設定したほうがいいけど、まあ学習向けだしそこらへんはどんぶり勘定)

// mul は |x| * |y| の絶対値による乗算を行う
func mul(x, y digits) digits {
    m, n := len(x), len(y)
    switch {
    case m < n:
        return norm(mul(y, x))
    case m == 0 || n == 0:
        return digits{}
    }

    if m < karatsubaThreshold && n < karatsubaThreshold {
        return norm(basicMul(x, y))
    }

    // karatsubaThreshold までが2のべき乗となるようにpaddingをとる
    k := karatsubaLen(m, karatsubaThreshold)
    px := leftPad(x, k-len(x))
    py := leftPad(y, k-len(y))
    return norm(karatsuba(px, py))
}

// karatsubaLen はnが閾値まで2べきならnをそのまま返し、そうでなければ閾値まで2べきとなるようなn以上のできるだけ小さい数を返す
func karatsubaLen(n, threshold int) int {
    var i uint = 0
    for n > threshold {
        if n&1 == 1 {
            n++
        }
        n >>= 1
        i++
    }
    n <<= i
    return n
}

// karatsuba法は定数倍が大きいので、40桁以上の乗算について適用させるようにする
const karatsubaThreshold = 40

// leftPad は指定された数だけ上位の桁に0を追加します
func leftPad(x digits, n int) digits {
    m := len(x)
    l := m + n
    abs := make(digits, l)
    for i := 0; i < l; i++ {
        if i < n {
            abs[i] = 0
        } else {
            abs[i] = x[i-n]
        }
    }
    return abs
}

// rightPad は指定された数だけ下位の桁に0を追加します
func rightPad(x digits, n int) digits {
    m := len(x)
    l := m + n
    abs := make(digits, l)
    for i := 0; i < l; i++ {
        if i < m {
            abs[i] = x[i]
        } else {
            abs[i] = 0
        }
    }
    return abs
}

karatsubaLen() は今回実装するkaratsubaは40桁より小さくなったら通常のlong multiplicationで乗算するので、うまいことその範囲が2べきになるようにしてる

2秒くらいで思いついた処理で下1bit立ってたら2で割れないので1追加する!というのをずっとやるだけ。雑に書いた割にそこそこ効率良くて気に入ってる

leftPad() / rightPad() は右か左に指定量の0paddingつめて返すだけ

つぎはkaratsuba本体

// karatsuba は karatsuba's algorithm で乗算を行う
// 呼び出し側は len(x) == len(y) かつ karatsubaThreshold まで2のべき乗のサイズになっていることを保証すること
func karatsuba(x, y digits) digits {
    m := len(x)
    // len(x) について、奇数/閾値以下/0のいずれかなら通常の乗算にて計算する
    if m&1 != 0 || m <= karatsubaThreshold || m < 2 {
        return basicMul(x, y)
    }
    m2 := m >> 1
    x1, x0 := x[m2:], x[0:m2]
    y1, y0 := y[m2:], y[0:m2]

    x0y0 := karatsuba(x0, y0)
    x1y1 := karatsuba(x1, y1)

    // x0y1 + x1y0 = (x0-x1)(y1-y0) + (x0y0 + x1y1) となるのでその計算
    // (x0+x1)(y1+y0) - (x0y0 + x1y1) の形にも整理できるが、
    // 加算はcarryが発生する可能性があり後続の再帰処理にて2のべき乗のサイズとならない可能性があるため減算の形で扱っている
    s := 1
    xd := make(digits, m2)
    if cmp(x0, x1) >= 0 {
        xd = basicSub(x0, x1)
    } else {
        s = -s 
        xd = basicSub(x1, x0)
    }
    yd := make(digits, m2)
    if cmp(y1, y0) >= 0 {
        yd = basicSub(y1, y0)
    } else {
        s = -s
        yd = basicSub(y0, y1)
    }
    var p digits
    if s < 0 {
        p = basicSub(basicAdd(x0y0, x1y1), karatsuba(xd, yd))
    } else {
        p = basicAdd(karatsuba(xd, yd), basicAdd(x0y0, x1y1))
    }

    // x1y1*(10^m) + p*(10^m2) + x0y0
    x0y0 = rightPad(x0y0, m)
    p = rightPad(p, m2)
    return basicAdd(basicAdd(x0y0, x1y1), p)
}

if m&1 != 0 || m <= karatsubaThreshold || m < 2 { のところは少し厚く判定しすぎてて、今回の実装なら多分karatsubaThresholdとの比較だけで十分なはず。安全に倒して 標準パッケージの実装に寄せた だけ

注意点としては x0-x1y1-y0 を計算するときに負数になる可能性があるところ。どちらか一方が負数だと (x0-x1)(y1-y0) はマイナスになるのでその後足し合わせるときに減算として扱っていたりする

allocはマジで適当にやってて再帰のたびに新規割当するので、このあとベンチ取るんですがなかなかエグい感じになってます

比較

ベンチマークのコードをサクッと書く

var (
    x1000   = new(Int).SetString("9876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210").abs
    pad1000 = leftPad(x1000, karatsubaLen(len(x1000), karatsubaThreshold)-len(x1000))
    x300    = new(Int).SetString("987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210").abs
    pad300  = leftPad(x300, karatsubaLen(len(x300), karatsubaThreshold)-len(x300))
    x200    = new(Int).SetString("98765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210").abs
    pad200  = leftPad(x200, karatsubaLen(len(x200), karatsubaThreshold)-len(x200))
    x100    = new(Int).SetString("9876543210987654321098765432109876543210987654321098765432109876543210987654321098765432109876543210").abs
    pad100  = leftPad(x100, karatsubaLen(len(x100), karatsubaThreshold)-len(x100))
)

func BenchmarkKaratsuba_len1000(b *testing.B) {
    for i := 0; i < b.N; i++ {
        karatsuba(pad1000, pad1000)
    }
}

func BenchmarkBasicMul_len1000(b *testing.B) {
    for i := 0; i < b.N; i++ {
        basicMul(x1000, x1000)
    }
}

func BenchmarkKaratsuba_len300(b *testing.B) {
    for i := 0; i < b.N; i++ {
        karatsuba(pad300, pad300)
    }
}

func BenchmarkBasicMul_len300(b *testing.B) {
    for i := 0; i < b.N; i++ {
        basicMul(x300, x300)
    }
}

func BenchmarkKaratsuba_len200(b *testing.B) {
    for i := 0; i < b.N; i++ {
        karatsuba(pad200, pad200)
    }
}

func BenchmarkBasicMul_len200(b *testing.B) {
    for i := 0; i < b.N; i++ {
        basicMul(x200, x200)
    }
}

func BenchmarkKaratsuba_len100(b *testing.B) {
    for i := 0; i < b.N; i++ {
        karatsuba(pad100, pad100)
    }
}

func BenchmarkBasicMul_len100(b *testing.B) {
    for i := 0; i < b.N; i++ {
        basicMul(x100, x100)
    }
}

実行!

$ go test -bench . -benchmem
goos: darwin
goarch: amd64
pkg: github.com/convto/mycrypto/big
cpu: Intel(R) Core(TM) i5-8257U CPU @ 1.40GHz
BenchmarkKaratsuba_len1000-8         861       1332737 ns/op      187013 B/op       1453 allocs/op
BenchmarkBasicMul_len1000-8          397       2991878 ns/op        2048 B/op          1 allocs/op
BenchmarkKaratsuba_len300-8         5594        193421 ns/op       20000 B/op        157 allocs/op
BenchmarkBasicMul_len300-8          4473        269398 ns/op         640 B/op          1 allocs/op
BenchmarkKaratsuba_len200-8        21901         54226 ns/op       14000 B/op        157 allocs/op
BenchmarkBasicMul_len200-8         10000        119414 ns/op         416 B/op          1 allocs/op
BenchmarkKaratsuba_len100-8        59186         19514 ns/op        3856 B/op         49 allocs/op
BenchmarkBasicMul_len100-8         40922         29108 ns/op         208 B/op          1 allocs/op
PASS
ok      github.com/convto/mycrypto/big  11.665s

ちゃんといい感じに早くなっている!

メモリ割当がめちゃめちゃ適当で再帰するたびに割り当てられるので、alloc回数と割当量がやばいことになってるけどまあ今回は無視する。余裕があれば直すかも

ちなみにgoの標準パッケージはメモリ割当もふくめちゃんとkaratsubaしてて、しかも メモリレイアウトとかも丁寧にコメントしてる のでめちゃめちゃ参考になったりする

これにて(メモリ割当以外は)そこそこ実用的な多倍長整数の乗算が実装できたぞ!