Ex - A Nameless Counting Problem Editorial by yuto1115

補足 (公式解説[1]の高速化)

公式解説における「良い数列」、および kyopro_friends さんの解説における「array」を高速に求める方法を解説します。本解説では公式解説に則って良い数列と呼ぶことにします。また、数列の長さは \(n\) で固定します(そのため、実際には下記のアルゴリズムを \(n=1,2,\dots,N\) に対して繰り返し行う必要があります)。

なお、このパートは KUPC2017 で出題されたことがあり、その解説では \(O(\log n\log M)\) の解法が説明されていますが、本解説ではそれとは別のアプローチで同じ計算量を達成します。

桁 DP の場合と同様に上位の bit から見ていきます。「\(M\) より真に小さいことが確定した要素」が出現する最初の bit が \(k\) bit 目であるとします。この \(k\) を全探索します。以下、\(k\) は固定されたものとします。

条件より明らかに \(M\)\(k\) bit 目は \(1\) であり、\(k\) bit 目の時点で「\(M\) より真に小さいことが確定した要素」は「\(k\) bit 目が \(0\) である要素」と言い換えられます。「\(k\) bit 目が \(0\) である要素」の個数を \(A\) としたとき、数列の総 xor の \(k\) bit 目に対する条件から \(A\) の偶奇が定まります(逆に、偶奇以外の条件はありません)。ここで、「\(k\) bit 目が \(0\) である要素」を一つ選んで \(y\) とすると、それ以外の要素の \(k-1\) bit 目以下が何であっても、数列の総 xor に関する条件から \(y\)\(k-1\) bit 目以下がただ一通りに定まります。よって、\(A\) を固定した時、良い数列の個数は

\[\binom{n}{A}(2^k)^{A-1}(M_K+1)^{n-A}\]

と表せます(ただし、\(M_k\)\(M\)\(k\) bit 目以上を \(0\) にしたもの)。

よって、\(A\) を全探索すれば、\(k\) の全探索と合わせて \(O(n\log M)\) で解くことができます。これでも桁 DP よりは早いですが、更に高速化することができます。 \(A\) に関する条件が偶奇しかないことから、結局のところ、多項式 \(f(x)=((M_k+1)+2^kx)^n\) における \(x\) の奇数(あるいは偶数)乗の係数の総和が求まれば良いです。これは xor convolution (長さ \(2\) の巡回畳み込み) により \(O(\log n)\) で実現できるので、\(k\) の全探索と合わせて \(O(\log n\log M)\) で解くことができます。

実装の際は、\(A=0\) の場合をカウントしてしまわないように気をつけてください。

実装例 (C++) : ただし、コード中の pow_2\(2\) の累乗を保持した配列

mint calc_f(int n, ll m, ll x) {
    if (!n) return (x ? 0 : 1);
    mint res;
    bool ok = true;
    for (int k = 59; k >= 0; k--) {
        int nm = m >> k & 1;
        int nx = x >> k & 1;
        if (nm) {
            int m_k = m - (m >> k << k);
            vector<mint> v = {pow_2[k], m_k + 1};
            
            // xor convolution
            v[0] += v[1];
            v[1] = v[0] - v[1] * 2;
            v[0] = v[0].pow(n);
            v[1] = v[1].pow(n);
            v[0] += v[1];
            v[1] = v[0] - v[1] * 2;
            v[0] /= 2;
            v[1] /= 2;
            
            mint now = v[nx];
            if ((n & 1) == nx) now -= mint(m_k + 1).pow(n);
            now /= pow_2[k];
            res += now;
        }
        if (nm * (n & 1) != nx) {
            ok = false;
            break;
        }
    }
    if (ok) ++res;
    return res;
}

おまけ:上のコードを単に \(n=1,2,\dots,N\) に対して繰り返すと \(O(N\log N \log M)\) ですが、簡単に \(O(N \log M)\) に落とせます。

posted:
last update: