ブログ名

競技プログラミングやお歌のお話をする高菜です。

Karatsuba 法を非再帰かつ一列に並んだ $O(n)$ サイズのメモリで実装する試み

概要

この記事では特に断りがない限り $n$ は常に $2$ 冪であるとします。

Karatsuba 法と呼ばれる長さ $n$(つまり次数 $n - 1$ 以下)の多項式同士の積を $O(n ^ { \log _ 2 (3) })$ time で計算するアルゴリズムが知られています。本記事では係数のサイズの $4n - 2$ 倍の一列に並んだメモリと $O(1)$ ワードのみを使用し、再帰関数を使わずに実装する方法をご紹介します。

おすすめの問題

Karatsuba 法でもギリギリ間に合うと思います。

問題
No.754 畳み込みの和 - yukicoder
提出(2,223 ms / 5,000 ms)
https://yukicoder.me/submissions/979356

きっかけ

マシュマロをいただきました。

アルゴリズム

$n$ を $2$ べきとして、長さ $n$ の多項式 $a, b$ の積 $c = a b$ を計算しましょう。次の仕様を持つアルゴリズムを記述することを目的にします。

入力
長さ $4n - 2$ の配列 $T$
初期状態
$T[2n - 2..]$ に step by $2$ で $a$ の係数が昇冪の順に書いてある
$T[2n - 1..]$ に step by $2$ で $b$ の係数が昇冪の順に書いてある
最終状態
$T$ に $c$ の係数が昇冪の順に書いてある

図で書くと次のような感じです。

$n = 4$ のときの初期状態・最終状態

図の記法の簡略化

次のようなルールを定めます

  • 長さ $2$ の赤い箱に bitset $S$ が書いてあるとき、左半分には $\sum _ { i \in S } a _ i$ が、右半分には $\sum _ { i \in S } b _ i$ が入っている
  • 長さ $2k$ の赤い箱に bitset 列 $S _ 0, \dots, S _ { k - 1 }$ が書いてあるとき、偶数番目には $\sum _ { i \in S _ 0 } a _ i, \dots, \sum _ { i \in S _ { k - 1 } } a _ i$ が、奇数番目には $\sum _ { i \in S _ 0 } b _ i, \dots, \sum _ { i \in S _ { k - 1 } } b _ i$ が入っている
    • つまり長さ $2$ の赤い箱が $k$ 個ならんでいてそれぞれ bitset $S _ 1, \dots, S _ k$ が書いてあるときと同じ値が入っているので、自由に併合・分割ができる
  • 長さ $2k$ の青い箱に bitset 列 $S _ 1, \dots, S _ k$ が書いてあるとき、$\left ( \sum _ { i \in S _ 0 } a _ i, \dots, \sum _ { i \in S _ { k - 1 } } a _ i \right ) \cdot \left ( \sum _ { i \in S _ 0 } b _ i, \dots, \sum _ { i \in S _ { k - 1 } } b _ i \right )$ の係数列が左詰めで書いてある
    • つまり、$S _ 1, \dots, S _ k$ の書いた長さ $k$ の赤い箱を多項式 $2$ 本とみなして積を取ったものです。

以上の記法を用いて最初の図を書き直すと次のようになります。

$n = 4$ のときの初期状態(簡略記法)

再帰版 Karatsuba 法の図解

先程定義した記法を用いると、再帰版の Karatsuba は次のように書けます。つまり再帰する前の値が赤色、再帰した後の値が青色になるわけですね。この図にある $3$ つの操作をそれぞれ拡張変換縮約と呼ぶことにします。

$n = 4$ のときの再帰版 Karatsuba 図解

再帰の中身を展開して葉まで書くと次のようになっています。葉をすべて同時にメモリに展開してしまうと $2 n ^ { \log _ 2 ( 3 ) }$ 個の係数がメモリを占めてしまうわけですが、再帰関数を使うとうまいこといって $O(n)$ memory になることも、この図を見ながら想像するとわかりやすいと思います。

$n = 4$ のときの再帰をすべて展開した図

なお赤色の箱が青色の箱に変わる部分は具体的には次のように、係数の掛け算を行っています。

$n = 4$ のときの葉の変色

一列に並んだメモリを使った再帰版

再帰呼び出しのたびに、余分なメモリを左から右に移動することにすると、次の図のように $3$ 回の再帰呼び出しでメモリを使いまわすことができます。 

一列のメモリを使った再帰版実装

疑似コードで書くとこういう感じになります。

  1. $T[n - 2..2n-2], T[2n-2..3n-2], T[3n-2..4n-2]$ にそれぞれ、低次項、高次項、和を展開する
  2. $T[0..2n - 2], T[n..3n - 2], T[2n..4n - 2]$ それぞれについてこの順に、再帰呼び出しをする
  3. $T[0..n], T[n..2n], T[2n..3n]$ に低次項、高次項、和の多項式の積が入っているので、そこから全体の積を復元する

$n = 4$ のときに葉まで書くと次のようになります。

$n = 4$ の場合の 一列メモリ Karatsuba 法

$n = 8$ のときに葉まで書くと次のようになります。ステップ数の省略のために葉の変色は省略して、中間色で書いています。

$n = 8$ の場合の 一列メモリ Karatsuba 法

非再帰化

アルゴリズムを通して青色の箱はすべて左寄せ、赤色の箱はすべて右寄せになっています。そこで青色の箱の右端を $L$、赤色の箱の左端を $R$ と置きます。さらに箱の長さは常に $2$ 以上の $2$ 冪になっている*1ので、長さ $2 \times 2 ^ i$ の青色・赤色の箱の個数をそれぞれ $l _ i, r _ i$ と置きます。このとき次の不変条件が成り立ちます。これは左辺が Karatsuba 法の $3$ つの操作、拡張、変換、縮約で不変であることからわかります。

$$ \sum _ { i \le 0 } ( l _ r + r _ i ) 3 ^ i = 2 ^ { \log _ 2 ( 3 ) } $$

また各 $l _ i, r _ i$ は常に $3$ 以下であり、さらに $\left( l _ i \right ), \left( r _ i \right )$ は常に skew tertiary number として valid な列になっていて、$3$ つの操作、拡張、変換、縮約は次のものに対応しています。

  • 拡張: $\left( r _ i \right )$ の繰り下がりのある $1$ 減少
  • 変換: 繰り上がり・繰り下がりのない $\left( r _ i \right )$ の $1$ 減少・$\left( l _ i \right )$ の $1$ 増加
  • 縮約: $\left( l _ i \right )$ の繰り上がりのある $1$ 増加

言い換えると、普通の $3$ 進法で、$\left( r _ i \right )$ から $\left( l _ i \right )$ に移しながら変換を行い、その際繰り下がりが生じるならば繰り下がりと同時に拡張を、繰り上がりが生じるならば繰り上りと同時に縮約を実行していると思うこともできます。

非再帰化するにあたっての困難は二つあり、ひとつは offset がわからないこと、もうひとつはアルゴリズムの進行度合いがわからないことです。そこで本記事では $L, R$ の値と数列 $\left( l _ i \right )$ を管理します。なお数列は bitset で管理することで高々 $2$ ワード(常識的なサイズなら $1$ ワード)に収めることができます。

また作業用に、拡張・縮約操作のサイズ $k$ を管理し、常に $2$ 冪であるように保ちます。

まとめると、管理する情報は次のものです

  • $L, R$: 青色の箱の右端、赤色の箱の右端
  • $l _ i$: 青色の長さ $2 \cdot 2 ^ i$ の箱の個数
  • $k$ 拡張・縮約操作のサイズ

疑似コードは次のようになります。

  1. $L = 0, R = 2n - 2, l = (0, 0, \dots), k = n$
  2. if $l _ n = 1$、他の項がすべて $0$ ならば終了
  3. while $k \neq 1$ do: // 拡張 
    1. $k ← k/2$
    2. $R ← R - 2k$
    3. $T[R..R + 2k], T[R + 2k..R + 4k], T[R + 4k..R + 6k]$ がそれぞれ空、低次項、高次項なので、低次項、高次項、和になるようにする
  4. // 変換 ここでは $L = R$ が成り立っている
    1. $T[L], T[L + 1]$ がそれぞれ $a, b$ の何らかの和なのでそれぞれ、その積、$0$ になるようにする
    2. $L ← L + 2, R ← R + 2$
  5. while $l _ { \log _ 2 ( k ) } = 3$ do: // 縮約
    1. $T[L - 6k..L - 4k], T[L - 4k..L - 2k], T[L - 2k..L],$ がそれぞれ低次項、高次項、和の多項式の席なので、$T[L - 6k..L - 2k]$ が全体の積に、$[L - 2k..L]$ が $0$ になるようにする
    2. $L ← 2k, l _ { \log _ 2 ( k ) } ← 0, l _ { \log _ 2 ( k ) + 1 } ← l _ { \log _ 2 ( k ) + 1 } + 1$
    3. $k ← 2k$
  6. 第2行に戻る

Rust のソースコードで書くと次のようになります。ただし使いやすさを考えて (&[T], &[T]) -> Vec<T> にしているので、最初と最後には入出力の変換や $2$ 冪拡張・縮小が含まれています。変数の対応は次のようになっています。なお最初にご紹介した提出もこのコードを使っています。

  • left: $L$
  • right: $R$
  • size: $k$
  • structure: $\left( l _ i \right)$
fn karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
where
    T: Clone + Copy + Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + std::fmt::Debug,
{
    let n = a.len().max(b.len()).next_power_of_two();
    let mut slice = vec![T::zero(); 4 * n - 2];
    for (i, &x) in a.iter().enumerate() {
        slice[2 * (n - 1 + i)] = x;
    }
    for (i, &x) in b.iter().enumerate() {
        slice[2 * (n - 1 + i) + 1] = x;
    }
    let mut left = 0;
    let mut right = 2 * (n - 1);
    let mut structure = 0;
    let mut size = n;
    while structure != n * n {
        while size != 1 {
            size >>= 1;
            right -= 2 * size;
            for ((i, j), k) in (right..)
                .zip(right + 2 * size..)
                .zip(right + 4 * size..)
                .take(2 * size)
            {
                slice[i] = slice[j];
                slice[j] = slice[k];
                slice[k] = slice[i] + slice[j];
            }
        }
        let (x, y) = (slice[left], slice[left + 1]);
        (slice[left], slice[left + 1]) = (x * y, T::zero());
        left += 2;
        right += 2;
        structure += 1;
        while structure & (3 * size * size) == 3 * size * size {
            for ((i, j), k) in (left - 6 * size..)
                .zip(left - 4 * size..)
                .zip(left - 2 * size..)
                .take(2 * size - 1)
            {
                slice[k] = slice[k] - slice[i] - slice[j];
            }
            for (i, j) in (left - 5 * size..)
                .zip(left - 2 * size..)
                .take(2 * size - 1)
            {
                slice[i] = slice[i] + slice[j];
                slice[j] = T::zero();
            }
            left -= 2 * size;
            structure += size * size;
            size <<= 1;
        }
    }
    slice[..2 * n - 1].to_vec()
}

*1:$a$ と $b$ が interlace されて長さが $2$ 倍になっていることに注意