H - Nim Counting Editorial by en_translator
As the problem statement implies, this is the well-known stone-taking game Nim. Denoting by \(B_1,\ldots ,B_M\) the number of stones in each heap, the second player wins when \(B_1\oplus \cdots \oplus B_M=0\), and the first player wins otherwise. Here, \(\oplus\) denotes bitwise exclusive or (xor).
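As a minimal sketch of this winning criterion (the function name `firstPlayerWins` is only illustrative and does not appear in the sample code), the check is a single xor fold over the heap sizes:

```cpp
#include <cstdint>
#include <vector>

// Returns true when the first player wins the Nim position `heaps`,
// i.e. when the xor of all heap sizes is nonzero.
bool firstPlayerWins(const std::vector<std::uint32_t>& heaps) {
    std::uint32_t x = 0;
    for (std::uint32_t h : heaps) x ^= h;  // B_1 xor ... xor B_M
    return x != 0;
}
```

For example, the heaps \((1,2,3)\) have xor \(0\), so the second player wins there, while \((1,2)\) has xor \(3\), so the first player wins.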
Now let us fix the number of heaps \(M\). Consider the sequence \(C=(C_0,\ldots, C_{2^{16}-1})\) of length \(2^{16}\) such that \(C_i=1\) if there exists \(j\) with \(A_j=i\), and \(C_i=0\) otherwise. Define the (xor) convolution \(Z=(Z_i)=X*Y\) of two sequences \(X=(X_i), Y=(Y_i)\) of length \(2^K\) by
\[Z_k=\displaystyle\sum_{i\oplus j=k}X_iY_j.\]
Then the number of heap configurations in which the first player wins is the sum of all elements except the \(0\)-th one of \(C*C*\cdots(M\text{ times})\cdots*C\).
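To make the definition concrete, here is a direct quadratic-time sketch of the xor convolution. The name `xorConvolveNaive` is hypothetical, and the code assumes both inputs have the same power-of-two length with entries already reduced modulo \(998244353\). It is far too slow for length \(2^{16}\), but it is handy for checking a fast implementation on small inputs.

```cpp
#include <cstddef>
#include <vector>

// Naive xor convolution: Z[i ^ j] += X[i] * Y[j] (mod 998244353).
// Assumes X.size() == Y.size() and that the size is a power of two.
std::vector<long long> xorConvolveNaive(const std::vector<long long>& X,
                                        const std::vector<long long>& Y) {
    const long long MOD = 998244353;
    std::vector<long long> Z(X.size(), 0);
    for (std::size_t i = 0; i < X.size(); ++i) {
        for (std::size_t j = 0; j < Y.size(); ++j) {
            Z[i ^ j] = (Z[i ^ j] + X[i] * Y[j]) % MOD;
        }
    }
    return Z;
}
```

For instance, with \(A=(1,2)\) and \(K=2\) we have \(C=(0,1,1,0)\) and \(C*C=(2,0,0,2)\): of the four ordered pairs of heaps, \((1,2)\) and \((2,1)\) have xor \(3\) and are first-player wins, while \((1,1)\) and \((2,2)\) have xor \(0\).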
The convolution above can be computed quickly with the Hadamard transform. Specifically, denote the Hadamard transform by \(H\) and let \(X'=(X_i')=H(X)\), \(Y'=(Y_i')=H(Y)\), and \(Z'=(Z_i')=H(X*Y)\); then \(Z'_i=X'_iY'_i\). Moreover, for the unnormalized transform used in the sample code, \(H(H(X))=2^K X\) for any sequence \(X\) of length \(2^K\), so we can compute \(X*Y=2^{-K}H((X_i'Y_i'))\). Therefore, it is sufficient to find \(C*\cdots *C=2^{-K}H(({C_i'}^M))\), where \(C'=H(C)\).
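The following is a sketch of this transform-multiply-transform scheme, not the editorial's own code. The names `kMod`, `hadamard`, `powMod`, and `xorConvolve` are illustrative assumptions, inputs are assumed to have equal power-of-two length with entries in \([0, 998244353)\), and the transform is the unnormalized butterfly, so the result is divided by the length at the end.

```cpp
#include <cstddef>
#include <vector>

const long long kMod = 998244353;

// In-place unnormalized Walsh-Hadamard transform; a.size() must be a power of two.
void hadamard(std::vector<long long>& a) {
    const std::size_t n = a.size();
    for (std::size_t len = 1; len < n; len <<= 1) {
        for (std::size_t base = 0; base < n; base += 2 * len) {
            for (std::size_t j = 0; j < len; ++j) {
                long long u = a[base + j], v = a[base + len + j];
                a[base + j] = (u + v) % kMod;
                a[base + len + j] = (u - v + kMod) % kMod;
            }
        }
    }
}

// Modular exponentiation b^e mod kMod.
long long powMod(long long b, long long e) {
    long long r = 1;
    b %= kMod;
    while (e > 0) {
        if (e & 1) r = r * b % kMod;
        b = b * b % kMod;
        e >>= 1;
    }
    return r;
}

// Xor convolution via H, pointwise product, H again, then division by the length.
std::vector<long long> xorConvolve(std::vector<long long> x, std::vector<long long> y) {
    hadamard(x);
    hadamard(y);
    for (std::size_t i = 0; i < x.size(); ++i) x[i] = x[i] * y[i] % kMod;
    hadamard(x);
    const long long invN = powMod(static_cast<long long>(x.size()) % kMod, kMod - 2);
    for (long long& v : x) v = v * invN % kMod;
    return x;
}
```

Applying `hadamard` twice to \((1,2,3,4)\) gives \((4,8,12,16)\), which illustrates why the final division by the length is needed.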
Now let us unfix \(M\) and think the same way; then the final answer is the sum of all the elements other than the \(0\)-th one of \(2^{-K}H(({C_i'}+{C_i'}^2+\cdots +{C_i'}^N))\), where the factor \(2^{-K}\) accounts for applying the unnormalized transform twice. Each value \({C_i'}+{C_i'}^2+\cdots +{C_i'}^N\) can be computed from \(C'_i\) in \(O(\log N)\) time with the formula for the sum of a geometric series. Note, however, the cases \(C'_i\equiv 0,1 \pmod{998244353}\): when \(C'_i\equiv 1\) the formula would divide by zero and the sum is simply \(N\), and when \(C'_i\equiv 0\) the sum is \(0\).
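For reference, the closed form can be written out as follows, where \(r\) is shorthand introduced here (not in the original) for \(C'_i\) reduced modulo \(998244353\):

\[\sum_{m=1}^{N} r^m=\begin{cases}\dfrac{r\,(r^N-1)}{r-1} & (r\not\equiv 1),\\[2mm] N & (r\equiv 1).\end{cases}\]

The division is a multiplication by the modular inverse of \(r-1\), which exists because \(998244353\) is prime and \(r\not\equiv 1\). Computing \(r^N\) and the inverse by fast exponentiation gives the \(O(\log N)\) cost per element.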
The Hadamard transform of a sequence of length \(2^K\) can be computed in \(O(K2^K)\) time. Here it suffices to choose \(K\) so that \(\max(A_i)<2^K\), so this part can be calculated in a total of \(O(m_A\log m_A)\) time, where \(m_A=\max(A_i)\). Also, we can find the sums of the geometric series in a total of \(O(m_A\log N)\) time, so the overall time complexity is \(O(m_A(\log m_A+\log N))\), which is fast enough to solve this problem.
Sample code in C++:
#include <bits/stdc++.h>
using namespace std;
#define N (1<<16)
#define K 16
#define MOD 998244353
#define ll long long
#define rep(i, n) for(int i = 0; i < n; ++i)
#define rep2(i, a, b) for(int i = a; i <= b; ++i)
ll modpow(ll x, int a) {
x %= MOD;
ll re = 1;
while (a > 0) {
if (a % 2 == 1)re = (re*x) % MOD;
x = (x*x) % MOD;
a /= 2;
}
return re;
}
ll inv(ll x) {
return modpow(x, MOD - 2);
}
ll func(ll x, int n) {
x %= MOD;
if (x == 0)return (ll)0;
if (x == 1)return (ll)n;
ll re = modpow(x, n) - 1;
re = (re*x) % MOD;
re = (re*inv(x - 1)) % MOD;
return re;
}
int main(void) {
int p2[K + 1]; // p2[i] = 2^i; K + 1 entries because p2[kk + 1] is read with kk = K - 1
p2[0] = 1;
rep(i, K)p2[i + 1] = p2[i] * 2;
int n, k, d;
ll a[N];
rep(i, N)a[i] = 0;
ll x;
ll ans = 0;
cin >> n >> k;
rep(i, k) {
cin >> x;
a[x] = 1;
}
// First Hadamard transform: a[] goes from the indicator sequence C to C'
rep(kk, K) {
rep(i, p2[K - kk - 1]) {
d = p2[kk + 1] * i;
rep(j, p2[kk]) {
x = (a[d + j] - a[d + p2[kk] + j] + MOD) % MOD;
a[d + j] = (a[d + j] + a[d + p2[kk] + j]) % MOD;
a[d + p2[kk] + j] = x;
}
}
}
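// Replace each C'_i by the geometric sum C'_i + C'_i^2 + ... + C'_i^n
rep(i, N)a[i] = func(a[i], n);
// Second Hadamard transform: a[] becomes 2^K times the sequence of counts
rep(kk, K) {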
rep(i, p2[K - kk - 1]) {
d = p2[kk + 1] * i;
rep(j, p2[kk]) {
x = (a[d + j] - a[d + p2[kk] + j] + MOD) % MOD;
a[d + j] = (a[d + j] + a[d + p2[kk] + j]) % MOD;
a[d + p2[kk] + j] = x;
}
}
}
// Sum every element except the 0-th one, then divide by 2^K (= N)
rep2(i, 1, N - 1)ans += a[i];
ans %= MOD;
ans = (ans * (inv((ll)N))) % MOD;
cout << ans << endl;
return 0;
}