
H - Maximize XOR Editorial by en_translator


Notice the unusual constraint \(\dbinom{N}{K}\leq 10^6\). Since it guarantees that there are at most \(10^6\) ways to choose \(K\) distinct elements from \(A\), all such choices can be enumerated with an appropriate algorithm.

One approach to the enumeration is as follows:

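def func(x: list[int], i: int):
    # x: indices chosen so far; i: index of the next element to decide on
    if len(x) == K:
        # Here `x` is one of the length-K choices being enumerated
        return
    if i == N:
        return
    func(x, i + 1)  # do not choose element i
    func(x + [i], i + 1)  # choose element i

func([], 0)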

However, this does not finish within the time limit when \(K\) is large. The function func is called about \(\displaystyle \sum_{i=0}^K \dbinom{N}{i}\) times, and when \(K\geq N/2\) this sum includes the term \(\dbinom{N}{\lfloor N/2\rfloor}\), which can be enormous even if \(\dbinom{N}{K}\) itself is small.
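A small experiment makes this concrete. The sketch below instruments the naive recursion above with a call counter; the helper name `count_calls` is hypothetical, and the values \(N=16\), \(K\in\{2,14\}\) are arbitrary small cases chosen so that both have the same \(\dbinom{N}{K}=120\).

```python
import math


def count_calls(N: int, K: int) -> int:
    """Count invocations of the naive enumeration for given N and K."""
    calls = 0

    def func(x: list[int], i: int) -> None:
        nonlocal calls
        calls += 1
        if len(x) == K:
            return  # x is one complete length-K choice
        if i == N:
            return  # ran out of elements before reaching K
        func(x, i + 1)        # do not choose element i
        func(x + [i], i + 1)  # choose element i

    func([], 0)
    return calls


# Both cases have the same number of K-subsets, C(16, 2) = C(16, 14) = 120,
# but the naive recursion does far more work when K is large.
for N, K in [(16, 2), (16, 14)]:
    print(N, K, math.comb(N, K), count_calls(N, K))
```

The gap widens quickly as \(N\) grows, which is why the naive search is hopeless for inputs like the one in the next example.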

For example, \((N,K)=(100,98)\) satisfies the constraint because \(\dbinom{N}{K}=4950\lt 10^6\), but \(\dbinom{100}{50}\approx 1.01\times 10^{29}\).
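These numbers can be checked directly with `math.comb` (Python 3.8 or later). The last lines also evaluate the call-count estimate \(\sum_{i=0}^{K}\dbinom{N}{i}\) for this case, which equals \(2^{100}-\dbinom{100}{99}-\dbinom{100}{100}=2^{100}-101\).

```python
import math

print(math.comb(100, 98))                # 4950
print(f"{math.comb(100, 50):.3e}")       # 1.009e+29

# Estimated number of calls of the naive recursion for (N, K) = (100, 98):
# sum of C(100, i) for i = 0..98, i.e. everything except i = 99 and i = 100.
naive = sum(math.comb(100, i) for i in range(99))
print(naive == 2**100 - 101)             # True
```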

To finish within the time limit even when \(K\) is large, one can enumerate the \((N-K)\) elements that are not chosen instead of the \(K\) elements that are chosen. Specifically, branch as follows:

  • If \(K\leq N-K\): do the exhaustive search over the \(K\) chosen elements directly.
  • If \(K\gt N-K\): first compute the XOR of all \(N\) elements, then exhaustively search over the \((N-K)\) elements that are not chosen. Since \(x\oplus x=0\), XOR-ing the total with the unchosen elements cancels them out, leaving exactly the XOR of the chosen \(K\) elements.
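The identity behind the second branch can be checked by brute force on small inputs. The sketch below compares, for every \(K\)-element choice, the direct XOR with the total XOR combined with the unchosen elements. The helper name `check_complement` and the example arrays are hypothetical, chosen only for illustration; choices are made by index so that repeated values are handled like distinct elements.

```python
import itertools
from functools import reduce
from operator import xor


def check_complement(A: list[int], K: int) -> bool:
    """Check that XOR(chosen) == XOR(all) ^ XOR(unchosen) for every K-subset."""
    N = len(A)
    all_xor = reduce(xor, A, 0)
    for chosen in itertools.combinations(range(N), K):  # choose by index
        picked = set(chosen)
        direct = reduce(xor, (A[i] for i in chosen), 0)
        via_complement = reduce(
            xor, (A[i] for i in range(N) if i not in picked), all_xor
        )
        if direct != via_complement:
            return False
    return True


print(check_complement([3, 2, 6, 4, 7], 3))  # True
```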

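The time complexity is \(O\left(\dbinom{N}{K}\min(K,N-K)\right)\). Under the constraint \(\dbinom{N}{K}\leq 10^6\), we have \(\min(K,N-K)\leq 11\): if \(m=\min(K,N-K)\geq 12\), then \(N\geq 2m\) gives \(\dbinom{N}{K}=\dbinom{N}{m}\geq\dbinom{2m}{m}\geq\dbinom{24}{12}=2704156\gt 10^6\). Hence the program finishes within the time limit.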
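A quick numeric sanity check of this bound, as a sketch using `math.comb` (the helper name `max_min_side` is hypothetical): the loop finds the largest \(m\) with \(\dbinom{2m}{m}\leq 10^6\). Since \(\dbinom{N}{m}\) only grows with \(N\) for fixed \(m\), no valid \((N,K)\) can have \(\min(K,N-K)\) larger than this.

```python
import math


def max_min_side(limit: int = 10**6) -> int:
    """Largest m such that C(2m, m) <= limit."""
    m = 0
    while math.comb(2 * (m + 1), m + 1) <= limit:
        m += 1
    return m


print(max_min_side())      # 11
print(math.comb(22, 11))   # 705432
print(math.comb(24, 12))   # 2704156
```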

import sys

sys.setrecursionlimit(3 * 10**5)  # the recursion depth can reach N + 1

N, K = map(int, input().split())
A = list(map(int, input().split()))

ans = 0


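def func(xor, idx, c):
    # xor: starting value XOR-ed with every element picked so far
    # idx: index of the next element to decide on
    # c: number of elements still to be picked
    global ans
    if c == 0:
        ans = max(ans, xor)
        return
    if idx == N:
        return
    func(xor ^ A[idx], idx + 1, c - 1)  # pick A[idx]
    func(xor, idx + 1, c)  # skip A[idx]


if K <= N - K:
    # Pick the K chosen elements directly
    func(0, 0, K)
else:
    # Start from the total XOR and pick the N - K unchosen elements to cancel
    all_xor = 0
    for i in range(N):
        all_xor ^= A[i]
    func(all_xor, 0, N - K)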

print(ans)
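To gain some confidence in the branching, one can wrap the program above in a function and compare it against a plain brute force on small random inputs. In this sketch the names `solve` and `brute` are hypothetical, `nonlocal` replaces the `global` variable so the search can be called repeatedly, and the input ranges are arbitrary small values chosen for testing.

```python
import itertools
import random
from functools import reduce
from operator import xor


def solve(N: int, K: int, A: list[int]) -> int:
    """The branching search above, wrapped as a reusable function."""
    best = 0

    def func(acc: int, idx: int, c: int) -> None:
        nonlocal best
        if c == 0:
            best = max(best, acc)
            return
        if idx == N:
            return
        func(acc ^ A[idx], idx + 1, c - 1)  # pick A[idx]
        func(acc, idx + 1, c)               # skip A[idx]

    if K <= N - K:
        func(0, 0, K)
    else:
        func(reduce(xor, A, 0), 0, N - K)
    return best


def brute(N: int, K: int, A: list[int]) -> int:
    """Reference answer: XOR every K-subset directly."""
    return max(reduce(xor, comb, 0) for comb in itertools.combinations(A, K))


random.seed(0)
for _ in range(200):
    N = random.randint(1, 10)
    K = random.randint(1, N)
    A = [random.randrange(1 << 8) for _ in range(N)]
    assert solve(N, K, A) == brute(N, K, A)
print(solve(4, 2, [3, 2, 6, 4]))  # 7
```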

Alternatively, Python's standard library provides itertools.combinations, which enumerates all length-\(K\) subsequences of a given length-\(N\) sequence, so the search can be written without explicit recursion.
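The short sketch below illustrates the properties of itertools.combinations that the program relies on (the example lists are arbitrary): tuples come out in lexicographic order of input positions, positions rather than distinct values are chosen, and exactly \(\dbinom{N}{K}\) tuples are produced.

```python
import itertools
import math

# Tuples come out in lexicographic order of positions in the input.
print(list(itertools.combinations([3, 2, 6], 2)))  # [(3, 2), (3, 6), (2, 6)]

# Positions, not values, are chosen: equal values at different positions
# give separate tuples, i.e. K distinct elements (indices) of A.
print(list(itertools.combinations([5, 5, 1], 2)))  # [(5, 5), (5, 1), (5, 1)]

# The number of tuples produced is exactly C(N, K).
print(sum(1 for _ in itertools.combinations(range(10), 3)) == math.comb(10, 3))  # True
```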

import itertools

N, K = map(int, input().split())
A = list(map(int, input().split()))

ans = 0

if K <= N - K:
    for a in itertools.combinations(A, K):
        xor = 0
        for i in a:
            xor ^= i
        ans = max(ans, xor)
else:
    all_xor = 0
    for i in range(N):
        all_xor ^= A[i]
    for a in itertools.combinations(A, N - K):
        xor = all_xor
        for i in a:
            xor ^= i
        ans = max(ans, xor)

print(ans)
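For reference, the same combinations-based approach can be written more compactly with functools.reduce and operator.xor from the standard library. This is only a stylistic rewrite of the program above, not a different algorithm; the function name `max_xor` is hypothetical, and input reading is left out so the function can be called directly.

```python
import itertools
from functools import reduce
from operator import xor


def max_xor(N: int, K: int, A: list[int]) -> int:
    # Enumerate whichever side is smaller: the K chosen elements starting
    # from 0, or the N - K unchosen ones starting from the total XOR.
    if K <= N - K:
        start, size = 0, K
    else:
        start, size = reduce(xor, A, 0), N - K
    return max(reduce(xor, c, start) for c in itertools.combinations(A, size))


print(max_xor(4, 2, [3, 2, 6, 4]))  # 7
```

When \(K=N\), `size` is 0 and combinations yields a single empty tuple, so the result is simply the total XOR.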
