G - Sum of (XOR^K or 0) Editorial by en_translator
This editorial requires the knowledge of XOR convolution and Hadamard transform. See also the editorial of a past problem of ABC.
Let \(L = 20\). For \(i=0,1,\ldots,2^L-1\), define \(C_i\) as “the number of subsequences of \(A\) whose length is a multiple of \(M\) and the total XOR equals \(i\).” Since the answer is represented as \(\displaystyle \sum_{i=0}^{2^L-1}C_i\cdot i^K\), it is sufficient to find \(C_0,C_1,\ldots,C_{2^L-1}\).
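For checking a fast implementation on small inputs, the definition of \(C_i\) can be evaluated directly by enumerating subsequences. The following is only a brute-force sketch; the names `count_by_xor` and `answer` are made up for illustration. It counts the empty subsequence in \(C_0\), which does not change the answer as long as \(K \geq 1\) (an assumption here), because \(0^K = 0\).

```python
from itertools import combinations

MOD = 998244353

def count_by_xor(A, M):
    """Return {i: C_i}: the number of subsequences of A whose length is a
    multiple of M (the empty one included) and whose total XOR is i."""
    C = {}
    for size in range(0, len(A) + 1, M):
        for chosen in combinations(A, size):
            x = 0
            for value in chosen:
                x ^= value
            C[x] = C.get(x, 0) + 1
    return C

def answer(A, M, K):
    """Sum of C_i * i^K modulo MOD, straight from the definition."""
    return sum(c * pow(i, K, MOD) for i, c in count_by_xor(A, M).items()) % MOD

print(answer([1, 2, 3], 2, 2))  # 9 + 4 + 1 = 14
```

The running time is exponential in \(N\), so this is useful only as a reference for tiny tests.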
Let \(S\) be the set of the sequences of length \(2^L\) where each term is a polynomial.
Define \(v_i\in S\ (i=1,2,\ldots,N)\) as follows:
- The \(0\)-th term of \(v_i\) is \(1\), the \(A_i\)-th term is \(x\), and all other terms are \(0\).
- However, if \(A_i=0\), then the \(0\)-th term is \(1+x\), and all other terms are \(0\).
Also, for \(a,b\in S\), we define the XOR convolution \(a\ast b\in S\) as follows (where \(\times\) denotes an ordinary multiplication of polynomials):
\[(a\ast b)_k=\displaystyle\sum_{i\oplus j=k}a_i\times b_j\]
Let us denote by \(V=v_1\ast v_2 \ast \dots \ast v_N\) the XOR convolution of \(v_1,v_2,\ldots,v_N\). Then the coefficient of \(x^j\) in the \(i\)-th term of \(V\) (\([x^j]V_i\)) equals the number of length-\(j\) subsequences of \(A\) with a total XOR of \(i\). Thus, \(C_i=[x^0]V_i+[x^M]V_i+[x^{2M}]V_i+\dots=[x^0]V_i\ (\mathrm{mod}\ 1-x^M)\).
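To make the meaning of \([x^j]V_i\) concrete, here is a naive sketch that builds \(V\) straight from the definition of the XOR convolution. Polynomials are stored as coefficient lists of length \(N+1\); the helper names `xor_convolve` and `subsequence_table` and the small size \(2^2\) are assumptions chosen for illustration, and the quadratic cost makes this suitable for tiny inputs only.

```python
def xor_convolve(a, b, deg):
    """XOR convolution of two sequences of polynomials.
    Each polynomial is a list of deg + 1 integer coefficients."""
    size = len(a)
    out = [[0] * (deg + 1) for _ in range(size)]
    for i in range(size):
        for j in range(size):
            for p, ap in enumerate(a[i]):
                if ap == 0:
                    continue
                for q, bq in enumerate(b[j]):
                    if p + q <= deg:
                        out[i ^ j][p + q] += ap * bq
    return out

def subsequence_table(A, bits):
    """V[i][j] = number of length-j subsequences of A with total XOR i."""
    size, deg = 1 << bits, len(A)
    V = [[0] * (deg + 1) for _ in range(size)]
    V[0][0] = 1  # identity element of the XOR convolution
    for value in A:
        v = [[0] * (deg + 1) for _ in range(size)]
        v[0][0] += 1      # "do not pick this element"
        v[value][1] += 1  # "pick it"; for value == 0 this makes the term 1 + x
        V = xor_convolve(V, v, deg)
    return V

V = subsequence_table([1, 2, 3], 2)
print(V)  # [[1, 0, 0, 1], [0, 1, 1, 0], [0, 1, 1, 0], [0, 1, 1, 0]]

M = 2
C = [sum(row[::M]) for row in V]  # keep only the coefficients of x^0, x^M, x^{2M}, ...
print(C)  # [1, 1, 1, 1]
```

For example, \(V_0 = 1 + x^3\) here because only the empty subsequence and the whole sequence \((1,2,3)\) have XOR \(0\).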
A well-known method for fast XOR convolution is the Hadamard transform. Denoting the Hadamard transform by \(H\), we have the following property for \(a,b\in S\):
\[H(a\ast b)_i=H(a)_i\times H(b)_i.\]
By applying the equation above multiple times, we have \(H(V)_i=H(v_1)_i\times H(v_2)_i\times \dots \times H(v_N)_i\). Now let us observe the structure of \(H(v_i)\).
The \((i,j)\) component of the Hadamard matrix is \((-1)^{\mathrm{popcount}(i\ \mathrm{AND}\ j)}\). If you apply the Hadamard transform to a sequence where exactly one element is \(1\) and all others are \(0\), every resulting term is either \(1\) or \(-1\). (For example, \(H((0,0,0,1,0,0,0,0))=(1,-1,-1,1,1,-1,-1,1)\).) In particular, if it is applied to a sequence where only the \(0\)-th term is \(1\), every term of the result is \(1\). Since \(v_i\) has \(1\) at the \(0\)-th term and \(x\) at the \(A_i\)-th term, by linearity each term of \(H(v_i)\) is either \(1+x\) or \(1-x\). (For example, \(H((1,0,0,x,0,0,0,0))=(1+x, 1-x, 1-x, 1+x, 1+x, 1-x, 1-x, 1+x)\).) Since \(H(V)\) is the element-wise product of \(H(v_1),H(v_2),\ldots,H(v_N)\), one can write \(H(V)_i=(1+x)^{B_i}(1-x)^{N-B_i}\), where the integer \(B_i\) is the number of indices \(j\) such that \(H(v_j)_i=1+x\), that is, such that \(\mathrm{popcount}(i\ \mathrm{AND}\ A_j)\) is even.
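The two facts used above, the \(\pm 1\) pattern of a transformed unit vector and the product property \(H(a\ast b)_i=H(a)_i\times H(b)_i\), can be checked numerically on integer sequences. This is a small sketch with illustrative helper names; the test vectors `a` and `b` are arbitrary.

```python
def hadamard(seq):
    """Unnormalized Hadamard transform; len(seq) must be a power of two."""
    a = list(seq)
    half = 1
    while half < len(a):
        for start in range(0, len(a), 2 * half):
            for j in range(start, start + half):
                a[j], a[j + half] = a[j] + a[j + half], a[j] - a[j + half]
        half *= 2
    return a

def hadamard_by_definition(seq):
    """Same transform via the matrix entries (-1)^popcount(i AND j)."""
    n = len(seq)
    return [sum((-1) ** bin(i & j).count("1") * seq[j] for j in range(n))
            for i in range(n)]

def xor_convolve(a, b):
    """Naive XOR convolution of two integer sequences of equal length."""
    out = [0] * len(a)
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            out[i ^ j] += ai * bj
    return out

signs = hadamard([0, 0, 0, 1, 0, 0, 0, 0])
print(signs)  # [1, -1, -1, 1, 1, -1, -1, 1]

a = [1, 2, 0, 5, 0, 0, 3, 1]
b = [2, 0, 1, 0, 4, 0, 0, 1]
lhs = hadamard(xor_convolve(a, b))
rhs = [p * q for p, q in zip(hadamard(a), hadamard(b))]
print(lhs == rhs)
```

With `signs` as above, the \(i\)-th term of \(H((1,0,0,x,0,0,0,0))\) is \(1 + \mathrm{signs}_i\, x\), which reproduces the example in the text.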
While computing this \(B\) naively according to the definition of the Hadamard transform costs \(O(4^L)\), the divide-and-conquer trick enables us to compute it in \(O(L 2^L)\) time. Alternatively, we may use \(B_i=\dfrac{H(\mathrm{cnt})_i+N}{2}\), where \(\mathrm{cnt}_i\) is the number of occurrences of \(i\) in \(A\); this holds because \(H(\mathrm{cnt})_i=B_i-(N-B_i)\).
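As a sanity check of the formula \(B_i=\dfrac{H(\mathrm{cnt})_i+N}{2}\), the sketch below compares it with a direct count of the elements \(a\) for which \(\mathrm{popcount}(i\ \mathrm{AND}\ a)\) is even. The function names and the tiny sizes are assumptions for illustration.

```python
def hadamard(seq):
    """Unnormalized Hadamard transform; len(seq) must be a power of two."""
    a = list(seq)
    half = 1
    while half < len(a):
        for start in range(0, len(a), 2 * half):
            for j in range(start, start + half):
                a[j], a[j + half] = a[j] + a[j + half], a[j] - a[j + half]
        half *= 2
    return a

def plus_counts(A, bits):
    """B_i computed from the transform of the occurrence counts."""
    cnt = [0] * (1 << bits)
    for value in A:
        cnt[value] += 1
    return [(h + len(A)) // 2 for h in hadamard(cnt)]

def plus_counts_naive(A, bits):
    """B_i by definition: elements a with popcount(i AND a) even."""
    return [sum(1 for a in A if bin(i & a).count("1") % 2 == 0)
            for i in range(1 << bits)]

print(plus_counts([1, 2, 3], 2))  # [3, 1, 1, 1]
```

Since \(H(\mathrm{cnt})_i\) and \(N\) always have the same parity, the integer division by \(2\) is exact.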
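Once we have obtained \(B\), all that is left is to find \([x^0](1+x)^{B_i}(1-x)^{N-B_i}\ (\mathrm{mod}\ 1-x^M)\) for each \(i\) and apply the inverse Hadamard transform (the Hadamard transform followed by division by \(2^L\)) to obtain \(C\). Since \(0 \leq B_i \leq N\), it suffices to evaluate \([x^0](1+x)^{b}(1-x)^{N-b}\ (\mathrm{mod}\ 1-x^M)\) for \(b=0,1,\ldots,N\), which can be done in a total of \(O(NM)\) time by precalculating \((1+x)^n,(1-x)^n\ (\mathrm{mod}\ 1-x^M)\) for \(n=0,1,\ldots,N\).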
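The quantity \([x^0](1+x)^{b}(1-x)^{n-b}\ (\mathrm{mod}\ 1-x^M)\) is the sum of the coefficients of \(x^0, x^M, x^{2M}, \ldots\) in the ordinary product. The sketch below computes it with cyclic multiplication of length-\(M\) coefficient lists and compares the result with a direct expansion. It uses exact integers instead of arithmetic modulo \(998244353\), and a full cyclic product where only its constant term is needed, so it illustrates the idea rather than the \(O(NM)\) bound; all names are illustrative.

```python
def cyclic_mul(p, q, M):
    """Product of two polynomials modulo 1 - x^M (exponents wrap around M)."""
    r = [0] * M
    for i, pi in enumerate(p):
        for j, qj in enumerate(q):
            r[(i + j) % M] += pi * qj
    return r

def powers(base, n, M):
    """[base^0, base^1, ..., base^n], each reduced modulo 1 - x^M."""
    out = [[1] + [0] * (M - 1)]
    for _ in range(n):
        out.append(cyclic_mul(out[-1], base, M))
    return out

def constant_terms(n, M):
    """res[b] = [x^0] (1+x)^b (1-x)^(n-b) mod (1 - x^M) for b = 0..n."""
    plus = [0] * M
    plus[0] += 1
    plus[1 % M] += 1   # when M == 1, x is congruent to 1, so 1 + x becomes 2
    minus = [0] * M
    minus[0] += 1
    minus[1 % M] -= 1  # and 1 - x becomes 0
    F = powers(plus, n, M)
    G = powers(minus, n, M)
    return [cyclic_mul(F[b], G[n - b], M)[0] for b in range(n + 1)]

def direct(n, b, M):
    """Expand (1+x)^b (1-x)^(n-b) and add the coefficients of x^0, x^M, ..."""
    poly = [1]
    for t in range(n):
        sign = 1 if t < b else -1
        nxt = [0] * (len(poly) + 1)
        for i, c in enumerate(poly):
            nxt[i] += c
            nxt[i + 1] += sign * c
        poly = nxt
    return sum(poly[::M])

print(constant_terms(3, 2))  # [4, 0, 0, 4]
```

For \(A=(1,2,3)\) and \(M=2\) one has \(B=(3,1,1,1)\), so the values to be inverse-transformed are \((4,0,0,0)\); dividing their Hadamard transform \((4,4,4,4)\) by \(4\) gives \(C=(1,1,1,1)\).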
Therefore, the problem can be solved in a total of \(O(NM + 2^L(L +\log K))\) time.
mod = 998244353

def hadamard(a, m):
    # In-place Hadamard transform of a list of length 2^m (no normalization, no modulo).
    for k in range(m):
        i = 1 << k
        for j in range(1 << m):
            if not i & j:
                a[j], a[i | j] = (a[j] + a[i | j]), (a[j] - a[i | j])

N, M, K = map(int, input().split())
A = list(map(int, input().split()))
L = 20

# cnt[i] = number of occurrences of i in A
cnt = [0] * (1 << L)
for i in range(N):
    cnt[A[i]] += 1

F = [[0] * M for i in range(N + 1)] # (1 + x) ^ n (mod 1 - x ^ M)
G = [[0] * M for i in range(N + 1)] # (1 - x) ^ n (mod 1 - x ^ M)
F[0][0] = 1
G[0][0] = 1
for i in range(N):
    for j in range(M):
        # For j = 0, index j - 1 = -1 wraps around to M - 1, which is the reduction by x ^ M = 1.
        F[i + 1][j] = (F[i][j] + F[i][j - 1]) % mod
        G[i + 1][j] = (G[i][j] - G[i][j - 1]) % mod

# res[b] = [x ^ 0] (1 + x) ^ b (1 - x) ^ (N - b) (mod 1 - x ^ M)
res = [0] * (N + 1)
for i in range(N + 1):
    for j in range(M):
        res[i] += F[i][j] * G[N - i][-j]
    res[i] %= mod

hadamard(cnt, L)
B = [(cnt[i] + N) // 2 for i in range(1 << L)]
C = [res[B[i]] for i in range(1 << L)]

# Inverse transform: apply the transform again, then divide by 2 ^ L.
hadamard(C, L)
inv = pow(1 << L, mod - 2, mod)
ans = 0
for i in range(1 << L):
    C[i] = C[i] % mod * inv % mod
    ans += C[i] * pow(i, K, mod)
    ans %= mod
print(ans)
Evaluating \(B\) with divide-and-conquer
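# Replaces the two lines "hadamard(cnt, L)" and "B = ..." above;
# cnt must still hold the raw occurrence counts when calc is called.
def calc(off, k):
    # Handles the values a whose bits k and above agree with off.
    # Returns (H, c): c is the number of such values in A, and H[i] is the number
    # of them for which popcount(i AND a) is even, for i in range(1 << k).
    if k == 0:
        return ([cnt[off]], cnt[off])
    H0, c0 = calc(off, k - 1)                   # values with bit k - 1 cleared
    H1, c1 = calc(off | (1 << (k - 1)), k - 1)  # values with bit k - 1 set
    H = [0] * (1 << k)
    for i in range(1 << (k - 1)):
        H[i] = H0[i] + H1[i]
        # If i also has bit k - 1 set, the parity flips for the values counted in H1.
        H[i | (1 << (k - 1))] = H0[i] + c1 - H1[i]
    return (H, c0 + c1)

B = calc(0, L)[0]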