G - Smaller Sum Editorial by ngtkana


この問題は wavelet 行列に(必ずしも省メモリでない)データ構造を乗せる手法で解くことができます。

追加のデータ構造が省メモリでないため当然 wavelet 行列内の簡潔ビットベクトルも普通の累積和等で置き換えてもメモリ使用量への影響が少なく、また速そうなのでこの問題でもそうすると良いです。

以下その詳細をご説明しましょう。

\(A\) の順列列 \(B^{(0)}, B^{(1)}, \dots, B^{(p - 1)}\) の定義

まず \(A _ i ≤ 2 ^ p - 2\) を満たす \(p\) (高さ)を取ります。\( ≤ 2 ^ p - 1\) にしていない理由はクエリ処理で場合分けを減らすためです。さらに \(A\) の順列列 \(B^{(0)}, B^{(1)}, \dots, B^{(p - 1)}\) を次のように定義します。

  • \(B^{(0)}\)\(A\) を、第 \(p\) ビットの値で逆順に安定ソートしたもの
  • \(B^{(i)}\)\(B^{(i - 1)}\) を、第 \(p - i\) ビットの値で逆順に安定ソートしたもの

なおこのソートはキーが \(2\) 値のため線形時間で行うことができます。ちなみにすぐにわかるように、最終段の \(B^{(p - 1)}\) では要素は bit-reversed 順序に関して逆順にソートされた状態になっています。(逆順にしたのは私の好みで、これはどちらでも大丈夫です。)

さらに各 \(B^{(0)}, B^{(1)}, \dots, B^{(p - 1)}\) の累積和を取るなどして、範囲総和クエリに答えられる状態にしておきましょう。(ここでは記号は導入しません。)

rank \(R^{(0)}, R^{(1)}, \dots, R^{(p - 1)}\) の定義、通常の wavelet 行列との関連

さらにそれぞれ \(N + 1\) 項からなるような数列列 \(R^{(0)}, R^{(1)}, \dots, R^{(p - 1)}\) を、

  • \(R^{(0)} _j\)\(A _ 0, \dots, A _ { j - 1 }\) のうち第 \(p\) ビットがたっているもの(ソートで \(B^{(0)}\) の前半に送られたもの)の個数
  • \(R^{(i)} _ j\)\(B^{(i - 1)} _ 0, \dots, B^{(i - 1)} _ { j - 1 }\) のうち第 \(p - i\) ビットがたっているもの(ソートで \(B^{(i)}\) の前半に送られたもの)の個数

と定義します。通常の wavelet 行列ではむしろこの \(R\) に相当するものが rank と言われ、簡潔データ構造で管理されます。

クエリ処理

通常の wavelet 行列でいう \(\mathtt{rangefreq}\) クエリ(\((\textrm{添え字}, \textrm{値})\) の直積範囲カウントクエリ)と同様では、クエリ範囲を列 \(B^{(0)}, B^{(1)}, \dots, B^{(p - 1)}\) (ただし通常の wavelet 行列では明示的に管理されていない)における \(O(p)\) 個の連続部分列に分解し、その長さを足し合わせていましたね。この問題でも基本的には同様なのですが、長さの代わりにその subrange における \(B^{(i)}\) その総和を足し合わせることにすればよいです。

ちなみにクエリは \(L ≤ i < R, A_i < x\) の形に書き換えると楽なのですが、このとき \(x\)\(2^p\) 以上になるとちょっと面倒です(下位 \(p\) ビットだけを見て処理できなくなる)から \(2 ^ p - 1\) 以下になるように clamp しましょう。そうすると \(A\) の要素はすべて \(2 ^ p - 2\) 以下であってくれないと困りますね。これが最初の \(p\) の定義のところで謎に思ったより \(1\) 小さかった理由だったりします。値の条件 \(A_i < x\) が上界しかないので実装もかなり楽になりますね。

公式解説のマージソート木との関連(閑話休題)

マージソート木は根側から見ると、最初の状態で添字ではなくて値でソートされていて、葉側に降りるにつれて添字の大小で partition されていきます。ただし wavelet 行列とは異なり、再帰的にどんどん細分化されていくタイプなので、wavelet 行列というよりは wavelet 木(wavelet 行列の元になったアイデアで、wavelet 行列とは構造が若干違います)に近いですね。

このときの partition に対応する rank を覚えておくことでマージソート木でも似たようなことができて、公式解説の log も落ちるのではないかと思いますので、どなたかやってみていただけると私はたいへん嬉しく(あの?)て。具体的にはセグメント木の各ノードの中の数列内で二分探索をする必要がなくなって、根内の数列の範囲から芋づる式にすべての箇所の範囲がわかるはずですのでぜひです。

計算量

構築 \(O(N \log A)\) 時間、クエリ \(O( \log A)\) 時間なので、全部で \(O((N + Q) \log A)\) 時間です。(ただし \(A\) は数列 \(A\) の上界)

実装

Wavelet 行列は、簡潔データ構造にさえこだわらなければ実は簡単に書けます。以下、ライブラリを使わずに書きましたのでぜひともご参考にです。

\(O((N + Q) \log A)\) 時間とは思えない貧相な実行時間です。私の勘違いだったりよくない実装だったりを発見したときには ngtkana まで教えていただけると幸いです。

提出 (819 ms):https://atcoder.jp/contests/abc339/submissions/51210104

use itertools::izip;
use proconio::input;

fn main() {
    input! {
        n: usize,
        mut a: [usize; n],
        q: usize,
        queries: [(usize, usize, usize); q],
    }
    let lg = (a.iter().copied().max().unwrap() + 2)
        .next_power_of_two()
        .trailing_zeros() as usize;
    let mut cums = Vec::new();
    let mut ranks = Vec::new();
    for i in (0..lg).rev() {
        // ランクを計算
        let mut rank = vec![0; n + 1];
        for (j, &x) in a.iter().enumerate() {
            rank[j + 1] = rank[j] + ((x >> i) & 1);
        }
        // a を第 i ビットをキーに逆順ソート
        let mut swp = vec![usize::MAX; n];
        for (j, &x) in a.iter().enumerate() {
            swp[match (x >> i) & 1 {
                0 => rank[n] + j - rank[j],
                1 => rank[j],
                _ => unreachable!(),
            }] = a[j];
        }
        // a の累積和を計算
        a = swp;
        let mut cum = vec![0; n + 1];
        for (j, &x) in a.iter().enumerate() {
            cum[j + 1] = cum[j] + x;
        }
        ranks.push(rank);
        cums.push(cum);
    }
    let mut rng = 0;
    for (alpha, beta, gamma) in queries {
        let mut l = (alpha ^ rng) - 1;
        let mut r = beta ^ rng;
        let x = ((gamma ^ rng) + 1).min((1 << lg) - 1);
        let mut ans = 0;
        for (i, cum, rank) in izip!((0..lg).rev(), &cums, &ranks) {
            // (l, r) が次の段の前半・後半でどこに相当するかを計算
            let l1 = rank[l];
            let r1 = rank[r];
            let l0 = rank[n] + l - l1;
            let r0 = rank[n] + r - r1;
            (l, r) = match x >> i & 1 {
                0 => (l0, r0),
                1 => {
                    ans += cum[r0] - cum[l0];
                    (l1, r1)
                }
                _ => unreachable!(),
            };
        }
        println!("{ans}");
        rng = ans;
    }
}

posted:
last update: