Official

H - Nim Counting Editorial by en_translator


As the problem statement implies, this is the famous stone-picking game called Nim. Denoting by \(B_1,\ldots ,B_M\) the number of stones in each heap, the second player always wins when \(B_1\oplus \cdots \oplus B_M=0\); otherwise the first player wins. Here, \(\oplus\) denotes bitwise exclusive or (xor).

Now let us fix the number of piles \(M\). Also, we consider a sequence \(C=(C_0,\ldots, C_{2^{16}-1})\) of length \(2^{16}\) such that \(C_i=1\) if there exists \(j\) satisfying \(A_j=i\), or otherwise \(C_i=0\). Defining the convolution \(Z=(Z_i)=X*Y\) of two sequences \(X=(X_i), Y=(Y_i)\) of lengths \(2^K\) as

\[Z_k=\displaystyle\sum_{i\oplus j=k}X_iY_j,\]

the number of combinations of stones for which the first player wins is the sum of all elements, except for the \(0\)-th one, of \(C*C*\cdots(M\text{ times})\cdots*C\).

The convolution above can be computed quickly with a Hadamard transform. Specifically, denoting the Hadamard transform by \(H\), and defining \(X'=(X_i')=H(X)\), \(Y'=(Y_i')=H(Y)\), and \(Z'=(Z_i')=H(X*Y)\), we have \(Z'_i=X'_iY'_i\); moreover, \(H\) is its own inverse up to a constant factor (for the unnormalized transform on sequences of length \(2^K\), \(H(H(X))=2^KX\)), so we can compute \(X*Y=2^{-K}H((X_i'Y_i'))\). Therefore, this time it is sufficient to find \(C*\cdots *C=2^{-K}H(({C_i'}^M))\), where \(C'=H(C)\).

Now let us unfix \(M\) and reason in the same way; the final answer is then the sum of all the elements other than the \(0\)-th one of \(H(({C_i'}+{C_i'}^2+\cdots +{C_i'}^N))\), divided by \(2^{16}\) when \(H\) is the unnormalized transform (as in the sample code below). Each value \({C_i'}+{C_i'}^2+\cdots +{C_i'}^N\) can be computed from \(C'_i\) in \(O(\log N)\) time with the formula for the sum of a geometric series. Note, however, that the cases \(C'_i\equiv 0,1 \pmod{998244353}\) must be treated separately (for \(C'_i\equiv 1\) the formula would divide by zero).

The Hadamard transform of a sequence of length \(2^K\) can be computed in a total of \(O(K2^K)\) time. This time, it suffices to choose \(K\) such that \(\max(A_i)<2^K\), so this part can be calculated in a total of \(O(m_A\log m_A)\) time, where \(m_A=\max(A_i)\). Also, we can find the sums of the geometric series in a total of \(O(m_A\log N)\) time, so the overall time complexity is \(O(m_A(\log m_A+\log N))\), which is fast enough to solve this problem.

Sample code in C++:

#include <bits/stdc++.h>
using namespace std;

// Length of the array that is Hadamard-transformed (2^16 entries).
constexpr int N = 1 << 16;
// log2(N): number of butterfly passes needed by the transform.
constexpr int K = 16;
// Prime modulus under which every value is computed.
constexpr int MOD = 998244353;
using ll = long long;
// rep(i, n): loop i over 0 .. n-1.
// rep2(i, a, b): loop i over a .. b, both ends inclusive.
// The bounds are wrapped in parentheses so that expression arguments such as
// `c ? x : y` are compared as a whole instead of being re-parsed around '<'.
#define rep(i, n) for(int i = 0; i < (n); ++i)
#define rep2(i, a, b) for(int i = (a); i <= (b); ++i)

// Computes x^a modulo MOD by binary exponentiation.
//   x: base; any value is accepted and is first reduced into [0, MOD).
//   a: exponent; a <= 0 yields 1 (the empty product).
// Returns the power as a value in [0, MOD).
ll modpow(ll x, int a) {
	// '%' keeps the sign of the dividend in C++, so a negative base must be
	// shifted up explicitly; otherwise the result could come out negative.
	x %= MOD;
	if (x < 0)x += MOD;
	ll re = 1;
	for (; a > 0; a >>= 1) {
		// Multiply in the current square whenever the lowest exponent bit is set.
		if (a & 1)re = (re*x) % MOD;
		x = (x*x) % MOD;
	}
	return re;
}

// Modular multiplicative inverse of x modulo MOD, obtained as x^(MOD-2)
// via modpow (Fermat's little theorem, which relies on MOD being prime).
ll inv(ll x) {
	const int fermat_exponent = MOD - 2;
	return modpow(x, fermat_exponent);
}

// Returns x + x^2 + ... + x^n modulo MOD, as a value in [0, MOD).
//   x: common ratio; any value is accepted and is first reduced into [0, MOD).
//   n: number of terms; n <= 0 denotes the empty sum and yields 0.
ll func(ll x, int n) {
	if (n <= 0)return (ll)0;
	// Normalize first so that the special cases below also catch negative
	// representatives of 0 and 1.
	x %= MOD;
	if (x < 0)x += MOD;
	if (x == 0)return (ll)0;
	// Ratio 1: the closed form would divide by zero; the sum is simply n.
	if (x == 1)return (ll)(n % MOD);
	// Geometric series: x * (x^n - 1) / (x - 1), with the division done by
	// multiplying with the modular inverse.
	ll re = (modpow(x, n) - 1 + MOD) % MOD;
	re = (re*x) % MOD;
	re = (re*inv(x - 1)) % MOD;
	return re;
}


int main(void) {
	int p2[K];
	p2[0] = 1;
	rep(i, K - 1)p2[i + 1] = p2[i] * 2;
	int n, k, d;
	ll a[N];
	rep(i, N)a[i] = 0;
	ll x;
	ll ans = 0;

	cin >> n >> k;
	rep(i, k) {
		cin >> x;
		a[x] = 1;
	}
	rep(kk, K) {
		rep(i, p2[K - kk - 1]) {
			d = p2[kk + 1] * i;
			rep(j, p2[kk]) {
				x = (a[d + j] - a[d + p2[kk] + j] + MOD) % MOD;
				a[d + j] = (a[d + j] + a[d + p2[kk] + j]) % MOD;
				a[d + p2[kk] + j] = x;
			}
		}
	}
	rep(i, N)a[i] = func(a[i], n);
	rep(kk, K) {
		rep(i, p2[K - kk - 1]) {
			d = p2[kk + 1] * i;
			rep(j, p2[kk]) {
				x = (a[d + j] - a[d + p2[kk] + j] + MOD) % MOD;
				a[d + j] = (a[d + j] + a[d + p2[kk] + j]) % MOD;
				a[d + p2[kk] + j] = x;
			}
		}
	}

	rep2(i, 1, N - 1)ans += a[i];
	ans %= MOD;
	ans = (ans * (inv((ll)N))) % MOD;
	cout << ans << endl;

	return 0;
}

posted:
last update: