Official

E - popcount <= 2 Editorial by evima


First, for any \(x\) with \(0 \leq x < 2^M\), replacing \(A\) with \((A_1 \oplus x, A_2 \oplus x, \ldots, A_N \oplus x)\) does not change whether \(A\) satisfies the conditions, since \((A_i \oplus x) \oplus (A_j \oplus x) = A_i \oplus A_j\). For each value \(v\), applying this with \(x = v\) maps the valid sequences with \(A_1 = v\) bijectively onto the valid sequences with \(A_1 = 0\). Therefore, the answer is \(2^M\) times the number of valid sequences with \(A_1 = 0\). Hereafter, we add the condition \(A_1 = 0\).
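For small inputs this reduction can be checked by brute force. The sketch below assumes the task counts sequences of \(N\) integers in \([0, 2^M)\) with \(\operatorname{popcount}(A_i \oplus A_j) \leq 2\) for every pair \(i, j\), which is what the argument implies; `is_valid`, `count_all` and `count_first_zero` are hypothetical helper names.

```python
from itertools import product


def is_valid(seq):
    # Assumed condition: every pair of elements differs in at most two bits.
    return all(bin(a ^ b).count("1") <= 2 for a in seq for b in seq)


def count_all(n, m):
    # Enumerate every sequence of length n over [0, 2^m).
    return sum(is_valid(s) for s in product(range(1 << m), repeat=n))


def count_first_zero(n, m):
    # Enumerate only the sequences with A_1 = 0.
    return sum(is_valid((0,) + s) for s in product(range(1 << m), repeat=n - 1))


# The total should be 2^M times the count with A_1 = 0.
for n in range(1, 4):
    for m in range(1, 5):
        assert count_all(n, m) == (1 << m) * count_first_zero(n, m)
```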

Since \(A_1 = 0\), we have \(\operatorname{popcount}(A_i) = \operatorname{popcount}(A_i \oplus A_1) \leq 2\) for each \(i\).

Let \(S = \{A_i \mid \operatorname{popcount}(A_i) = 2\}\) be the set of distinct values with popcount \(2\) that appear in \(A\).

[1] When \(|S| = 0\)

For all \(i\), we have \(\operatorname{popcount}(A_i) \leq 1\). Then \(\operatorname{popcount}(A_i \oplus A_j) \leq \operatorname{popcount}(A_i) + \operatorname{popcount}(A_j) \leq 2\) holds automatically, so the only condition is \(\operatorname{popcount}(A_i) \leq 1\). Each of \(A_2, \ldots, A_N\) is either \(0\) or one of the \(M\) powers of two \(2^0, \ldots, 2^{M-1}\), so there are \((M+1)^{N-1}\) such sequences \(A\).
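A minimal brute-force sketch of case [1], assuming the task counts sequences over \([0, 2^M)\) whose pairwise XORs all have popcount at most \(2\); `count_case_empty` is a hypothetical name. It fixes \(A_1 = 0\), discards sequences that contain a popcount-2 value, and compares the remaining valid count with \((M+1)^{N-1}\).

```python
from itertools import product


def popcount(x):
    return bin(x).count("1")


def count_case_empty(n, m):
    # Valid sequences with A_1 = 0 in which no element has popcount 2.
    total = 0
    for rest in product(range(1 << m), repeat=n - 1):
        seq = (0,) + rest
        if any(popcount(a) == 2 for a in seq):
            continue  # belongs to a case with |S| >= 1
        if all(popcount(a ^ b) <= 2 for a in seq for b in seq):
            total += 1
    return total


for n in range(1, 4):
    for m in range(1, 5):
        assert count_case_empty(n, m) == (m + 1) ** (n - 1)
```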

[2] When \(|S| = 1\)

Let \(S = \{X\}\).

If \(\operatorname{popcount}(A_i) = 1\), the single bit set in \(A_i\) must also be set in \(X\); otherwise \(\operatorname{popcount}(A_i \oplus X) = 3\). Conversely, if this holds for every \(i\), then \(A\) satisfies the conditions. Therefore:

  • There are \(\binom{M}{2}\) ways to choose \(X\)
  • There are \(3\) values of \(A_i\) with \(\operatorname{popcount}(A_i) \leq 1\): \(0\) and the two bits of \(X\)
  • There is \(1\) value of \(A_i\) with \(\operatorname{popcount}(A_i) = 2\), namely \(X\)
  • At least one of \(A_2, \ldots, A_N\) must equal \(X\)

So there are \(\binom{M}{2}(4^{N-1} - 3^{N-1})\) sequences \(A\) that satisfy the conditions.
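The count for case [2] can be checked the same way on tiny inputs. This sketch assumes the task counts sequences over \([0, 2^M)\) whose pairwise XORs all have popcount at most \(2\); `count_case_one` is a hypothetical name, and `math.comb` (Python 3.8+) supplies \(\binom{M}{2}\).

```python
from itertools import product
from math import comb


def popcount(x):
    return bin(x).count("1")


def count_case_one(n, m):
    # Valid sequences with A_1 = 0 whose set S of popcount-2 values has size 1.
    total = 0
    for rest in product(range(1 << m), repeat=n - 1):
        seq = (0,) + rest
        s = {a for a in seq if popcount(a) == 2}
        if len(s) == 1 and all(popcount(a ^ b) <= 2 for a in seq for b in seq):
            total += 1
    return total


for n in range(1, 4):
    for m in range(1, 5):
        assert count_case_one(n, m) == comb(m, 2) * (4 ** (n - 1) - 3 ** (n - 1))
```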

[3] When \(|S| \geq 2\)

[3-a] When the bitwise AND of all elements in \(S\) is \(0\)

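Any two distinct elements of \(S\) must share exactly one set bit, since otherwise their XOR has popcount \(4\). Since in this case no bit is shared by all elements, it follows that \(S\) has exactly three elements: using three integers \(0 \leq k_1 < k_2 < k_3 < M\), we can write \(S = \{2^{k_1} + 2^{k_2}, 2^{k_1} + 2^{k_3}, 2^{k_2} + 2^{k_3}\}\). Therefore: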

  • There are \(\binom{M}{3}\) ways to determine \(k_1, k_2, k_3\)
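  • The allowed values for \(A_i\) are \(0, 2^{k_1} + 2^{k_2}, 2^{k_1} + 2^{k_3}, 2^{k_2} + 2^{k_3}\), which are \(4\) values (a single bit \(2^k\) is not allowed, because some element of \(S\) does not contain bit \(k\), and their XOR has popcount \(3\))
  • Each of \(2^{k_1} + 2^{k_2}, 2^{k_1} + 2^{k_3}, 2^{k_2} + 2^{k_3}\) must appear at least once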

By inclusion-exclusion over which of the three values is missing, there are \(\binom{M}{3}(4^{N-1} - 3 \times 3^{N-1} + 3 \times 2^{N-1} - 1)\) sequences \(A\) that satisfy the conditions.
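The factor \(4^{N-1} - 3 \times 3^{N-1} + 3 \times 2^{N-1} - 1\) counts sequences of length \(N-1\) over four symbols in which three specified symbols all appear. In the sketch below, symbol `0` stands for the value \(0\) and `1`, `2`, `3` for the three popcount-2 values; `count_covering` and `inclusion_exclusion` are hypothetical names, and the check is by direct enumeration.

```python
from itertools import product


def count_covering(length):
    # Sequences over {0, 1, 2, 3} containing each of 1, 2 and 3 at least once.
    return sum({1, 2, 3} <= set(seq) for seq in product(range(4), repeat=length))


def inclusion_exclusion(length):
    # All sequences, minus those missing one chosen symbol, plus those missing
    # two, minus the single all-zero sequence missing all three.
    return 4 ** length - 3 * 3 ** length + 3 * 2 ** length - 1


for length in range(0, 8):
    assert count_covering(length) == inclusion_exclusion(length)
```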

[3-b] When the bitwise AND of all elements in \(S\) is not \(0\)

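In this case, all elements of \(S\) share a common bit \(2^k\). This \(k\) is unique, because two distinct elements of \(S\) share only one bit, so every element of \(S\) has the form \(2^k + 2^j\) with \(j \neq k\). Conversely, any \(x, y\) of this form satisfy \(\operatorname{popcount}(x \oplus y) \leq 2\).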

Let \(k' \neq k\). At most one element of \(S\) contains bit \(k'\), so since \(|S| \geq 2\) there exists some \(x \in S\) with \(\operatorname{popcount}(2^{k'} \oplus x) = 3 > 2\). On the other hand, \(\operatorname{popcount}(2^k \oplus x) = 1 \leq 2\) for all \(x \in S\). Hence, if \(\operatorname{popcount}(A_i) = 1\), then \(A_i = 2^k\).
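The structural claims behind [3-a] and [3-b] (a set of pairwise-compatible popcount-2 values with \(|S| \geq 2\) either forms a triangle on three bits or shares exactly one bit \(k\), and in the latter case \(2^k\) is the only compatible single-bit value) can be checked exhaustively for small \(M\). This is a hypothetical sanity check, not part of the solution; `check_structure` is an illustrative name.

```python
from itertools import combinations


def popcount(x):
    return bin(x).count("1")


def check_structure(m):
    # All values in [0, 2^m) with exactly two bits set.
    pairs = [(1 << i) | (1 << j) for i, j in combinations(range(m), 2)]
    for size in range(2, len(pairs) + 1):
        for s in combinations(pairs, size):
            # Keep only families satisfying the pairwise condition.
            if any(popcount(x ^ y) > 2 for x, y in combinations(s, 2)):
                continue
            common = s[0]
            for x in s:
                common &= x
            if common == 0:
                # [3-a]: no shared bit, so S must be a triangle on three bits.
                assert size == 3 and popcount(s[0] | s[1] | s[2]) == 3
            else:
                # [3-b]: exactly one shared bit k, and 2^k is the only
                # single-bit value compatible with every element of S.
                assert popcount(common) == 1
                singles = [1 << b for b in range(m)
                           if all(popcount((1 << b) ^ x) <= 2 for x in s)]
                assert singles == [common]
    return True


for m in range(2, 6):
    assert check_structure(m)
```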

Therefore:

  • There are \(M\) ways to determine \(k\)
  • There are \(M-1\) possible values of \(A_i\) with \(\operatorname{popcount}(A_i) = 2\), namely \(2^k + 2^j\) for \(j \neq k\)
  • There are \(2\) possible values of \(A_i\) with \(\operatorname{popcount}(A_i) \leq 1\), namely \(0\) and \(2^k\)
  • At least two distinct values with \(\operatorname{popcount}(A_i) = 2\) must appear

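From the \((M+1)^{N-1}\) sequences over these \(M+1\) values, subtract those with no popcount-2 value (\(2^{N-1}\)) and those with exactly one distinct popcount-2 value (\((M-1)(3^{N-1} - 2^{N-1})\)). So there are \(M\left((M+1)^{N-1} - 2^{N-1} - (M-1)(3^{N-1} - 2^{N-1})\right)\) sequences \(A\) that satisfy the conditions.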
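A brute-force sketch for case [3-b], assuming the task counts sequences over \([0, 2^M)\) whose pairwise XORs all have popcount at most \(2\); `count_case_star` and `formula_star` are hypothetical names.

```python
from itertools import product


def popcount(x):
    return bin(x).count("1")


def count_case_star(n, m):
    # Valid sequences with A_1 = 0 where |S| >= 2 and all of S share a bit.
    total = 0
    for rest in product(range(1 << m), repeat=n - 1):
        seq = (0,) + rest
        s = {a for a in seq if popcount(a) == 2}
        if len(s) < 2 or not all(popcount(a ^ b) <= 2 for a in seq for b in seq):
            continue
        common = (1 << m) - 1
        for x in s:
            common &= x
        if common != 0:
            total += 1
    return total


def formula_star(n, m):
    # M * ((M+1)^(N-1) - 2^(N-1) - (M-1)(3^(N-1) - 2^(N-1)))
    return m * ((m + 1) ** (n - 1) - 2 ** (n - 1)
                - (m - 1) * (3 ** (n - 1) - 2 ** (n - 1)))


for n in range(1, 5):
    for m in range(1, 5):
        assert count_case_star(n, m) == formula_star(n, m)
```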

Summing the four cases and multiplying by \(2^M\) gives the answer. Each test case needs only a constant number of modular exponentiations, so the time complexity is \(O(\log N + \log M)\) per test case.

Sample Implementation (Python3)

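```python
import sys

input = sys.stdin.readline

mod = 998244353
two_inv = pow(2, mod - 2, mod)  # modular inverse of 2
six_inv = pow(6, mod - 2, mod)  # modular inverse of 6


def c2(n):
    # binomial coefficient C(n, 2) modulo mod
    return n * (n - 1) % mod * two_inv % mod


def c3(n):
    # binomial coefficient C(n, 3) modulo mod
    return n * (n - 1) % mod * (n - 2) % mod * six_inv % mod


def solve(n, m):
    n2 = pow(2, n - 1, mod)
    n3 = pow(3, n - 1, mod)
    n4 = pow(4, n - 1, mod)
    nm1 = pow(m + 1, n - 1, mod)
    ans = 0
    # [1] |S| = 0
    ans += nm1
    # [2] |S| = 1
    ans += c2(m) * (n4 - n3) % mod
    # [3-a] |S| >= 2 and the AND of S is 0 (triangle on k1 < k2 < k3)
    ko = n4 - 3 * n3 + 3 * n2 - 1
    ans += c3(m) * ko % mod
    # [3-b] |S| >= 2 and all elements of S share bit k
    ko = nm1 - n2 - (m - 1) * (n3 - n2)
    ans += m * ko % mod
    ans %= mod
    # undo the A_1 = 0 normalization
    ans *= pow(2, m, mod)
    ans %= mod
    return ans


for _ in range(int(input())):
    n, m = map(int, input().split())
    print(solve(n, m))
```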
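Because the sample reads from standard input at module level, the sketch below restates the same four-case formula with exact Python integers and compares it against exhaustive enumeration for small \(N\) and \(M\); the sample `solve` should return this value reduced modulo 998244353. It assumes the task counts sequences over \([0, 2^M)\) whose pairwise XORs all have popcount at most \(2\); `closed_form` and `brute_force` are hypothetical names.

```python
from itertools import product
from math import comb


def closed_form(n, m):
    # Sum of the four cases, then multiply by 2^M (exact integers, no modulus).
    p2, p3, p4, pm = 2 ** (n - 1), 3 ** (n - 1), 4 ** (n - 1), (m + 1) ** (n - 1)
    total = pm                                        # [1] |S| = 0
    total += comb(m, 2) * (p4 - p3)                   # [2] |S| = 1
    total += comb(m, 3) * (p4 - 3 * p3 + 3 * p2 - 1)  # [3-a] triangle
    total += m * (pm - p2 - (m - 1) * (p3 - p2))      # [3-b] common bit
    return total * 2 ** m


def brute_force(n, m):
    # Count sequences over [0, 2^m) with popcount(a ^ b) <= 2 for all pairs.
    return sum(
        all(bin(a ^ b).count("1") <= 2 for a in seq for b in seq)
        for seq in product(range(1 << m), repeat=n)
    )


for n in range(1, 4):
    for m in range(1, 5):
        assert closed_form(n, m) == brute_force(n, m), (n, m)
```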
