H - Nim Counting Editorial by mechanicalpenciI


問題名からも分かるようにこれはNim と呼ばれる有名な石取りゲームで、山にある石の数を \(B_1,\ldots ,B_M\) とすると、\(B_1\oplus \cdots \oplus B_M=0\) となるとき後手必勝、そうでないとき先手必勝であることが知られています。ここで、 \(\oplus\) は排他的論理和を表します。

ここで山の数 \(M\) を固定して考えます。また、長さ \(2^{16}\) の列 \(C=(C_0,\ldots, C_{2^{16}-1})\) であって、\(A_j=i\) となるような \(j\) が存在するならば \(C_i=1\) , そうでないならば \(C_i=0\) であるような列を考えます。 長さ \(2^K\)\(2\) つの列 \(X=(X_i), Y=(Y_i)\) のxorについての畳み込み \(Z=(Z_i)=X*Y\)

\[Z_k=\displaystyle\sum_{i\oplus j=k}X_iY_j\]

で定めると、先手が勝てるような石の組の数は \(C*C*\cdots(M個)\cdots*C\)の第 \(0\) 要素以外の総和です。

上の畳み込みはアダマール変換を使う事で高速に計算できます。具体的にはアダマール変換を \(H\) で表して, \(X'=(X_i')=H(X)\) , \(Y'=(Y_i')=H(Y)\) , \(Z'=(Z_i')=H(X*Y)\)とすると、 \(Z'_i=X'_iY'_i\) が成り立ち、一般に \(H(H(X))=X\) が成り立つことから、 \(X*Y=H((X_i'Y_i'))\)として求めることができます。 よって、今回の場合では \(C=H(C')\) とし、 \(C*\cdots *C=H(({C_i'}^M))\)とすればよいです。

\(M\) についての固定を解いて、同じことを考えると、 最終的な答えは\(H(({C_i'}+{C_i'}^2\cdots +{C_i'}^N))\)の第 \(0\) 要素以外の総和として求めることができます。 \({C_i'}+{C_i'}^2\cdots +{C_i'}^N\) の値は等比数列の和の公式を用いることで \(C'_i\) の値から \(O(\log N)\) で求まります。ただし、 \(C'_i\equiv 0,1 \pmod{998244353}\) の場合に注意してください。

長さ \(2^K\) に対するアダマール変換は \(O(K2^K)\) で計算でき, 今回は \(\max(A_i)<2^K\) であれば良い事から \(m_A=\max(A_i)\)として、 \(O(m_A\log m_A)\)でこの部分は計算できます。また、等比数列の和を求めるパートでは \(O(m_A\log N)\) でしたから、全体で計算量は \(O(m_A(\log m_A+\log N))\) であり、十分高速にこの問題を解く事が出来ました。

C++による実装例:

#include <bits/stdc++.h>
using namespace std;

#define N (1<<16)
#define K 16
#define MOD 998244353
#define ll long long
#define rep(i, n) for(int i = 0; i < n; ++i)
#define rep2(i, a, b) for(int i = a; i <= b; ++i)

ll modpow(ll x, int a) {
	x %= MOD;
	ll re = 1;
	while (a > 0) {
		if (a % 2 == 1)re = (re*x) % MOD;
		x = (x*x) % MOD;
		a /= 2;
	}
	return re;
}

ll inv(ll x) {
	return modpow(x, MOD - 2);
}

ll func(ll x, int n) {
	x %= MOD;
	if (x == 0)return (ll)0;
	if (x == 1)return (ll)n;
	ll re = modpow(x, n) - 1;
	re = (re*x) % MOD;
	re = (re*inv(x - 1)) % MOD;
	return re;
}


int main(void) {
	int p2[K];
	p2[0] = 1;
	rep(i, K - 1)p2[i + 1] = p2[i] * 2;
	int n, k, d;
	ll a[N];
	rep(i, N)a[i] = 0;
	ll x;
	ll ans = 0;

	cin >> n >> k;
	rep(i, k) {
		cin >> x;
		a[x] = 1;
	}
	rep(kk, K) {
		rep(i, p2[K - kk - 1]) {
			d = p2[kk + 1] * i;
			rep(j, p2[kk]) {
				x = (a[d + j] - a[d + p2[kk] + j] + MOD) % MOD;
				a[d + j] = (a[d + j] + a[d + p2[kk] + j]) % MOD;
				a[d + p2[kk] + j] = x;
			}
		}
	}
	rep(i, N)a[i] = func(a[i], n);
	rep(kk, K) {
		rep(i, p2[K - kk - 1]) {
			d = p2[kk + 1] * i;
			rep(j, p2[kk]) {
				x = (a[d + j] - a[d + p2[kk] + j] + MOD) % MOD;
				a[d + j] = (a[d + j] + a[d + p2[kk] + j]) % MOD;
				a[d + p2[kk] + j] = x;
			}
		}
	}

	rep2(i, 1, N - 1)ans += a[i];
	ans %= MOD;
	ans = (ans * (inv((ll)N))) % MOD;
	cout << ans << endl;

	return 0;
}

posted:
last update: