H - Maximize XOR Editorial by en_translator
Notice the unusual constraint \(\dbinom{N}{K}\leq 10^6\). Since it guarantees that there are at most \(10^6\) ways to choose \(K\) distinct elements from \(A\), all such choices can be enumerated with an appropriate algorithm.
One approach for the enumeration would be as follows:
def func(x: list[int], i: int):
    if len(x) == K:
        # `x` now holds one choice of K indices; process it here
        return
    if i == N:
        return
    func(x, i + 1)        # skip index i
    func(x + [i], i + 1)  # choose index i


func([], 0)
However, this does not finish within the time limit when \(K\) is large, because the function func is called about \(\displaystyle \sum_{i=0}^K \dbinom{N}{i}\) times, and this sum includes terms such as \(\dbinom{N}{\lfloor N/2\rfloor}\), which can be enormous even if \(\dbinom{N}{K}\) is small.
For example, \((N,K)=(100,98)\) satisfies the constraints because \(\dbinom{N}{K}=4950\lt 10^6\), but \(\dbinom{100}{50}\) is about \(10^{29}\).
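The figures in this example can be checked with `math.comb` from the standard library. The snippet below is only a verification sketch, not part of the solution; the variable name `approx_calls` is introduced here for illustration and follows the call-count estimate stated above.

```python
import math

N, K = 100, 98

# Number of valid choices: small enough to enumerate.
print(math.comb(N, K))  # 4950

# The middle binomial coefficient is already beyond any time limit.
print(math.comb(N, N // 2) > 10**29)  # True

# Estimated number of calls made by the naive recursion.
approx_calls = sum(math.comb(N, i) for i in range(K + 1))
print(approx_calls > 10**29)  # True
```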
To finish within the time limit even when \(K\) is large, one can enumerate the \((N-K)\) elements that are not chosen instead of the \(K\) chosen elements. Specifically, branch as follows:
- If \(K\leq N-K\): naively enumerate all choices of \(K\) elements.
- If \(K\gt N-K\): first compute the XOR of all elements, then exhaustively enumerate the \((N-K)\) elements that are not chosen. XORing the total with the XOR of the unchosen elements cancels them out, which yields the XOR of the chosen \(K\) elements fast.
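The second case relies on the fact that XORing a value twice cancels it. A minimal sketch of that identity follows; the array values and the chosen indices are made up for illustration.

```python
from functools import reduce
from operator import xor

A = [3, 2, 6, 4, 9]   # made-up example values
chosen = [0, 1, 3]    # indices of the K chosen elements (illustrative)
unchosen = [i for i in range(len(A)) if i not in chosen]

total_xor = reduce(xor, A, 0)
xor_chosen = reduce(xor, (A[i] for i in chosen), 0)
xor_unchosen = reduce(xor, (A[i] for i in unchosen), 0)

# XOR of the chosen elements == total XOR with the unchosen ones cancelled out.
print(xor_chosen == total_xor ^ xor_unchosen)  # True
```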
The time complexity is \(O\left(\dbinom{N}{K}\min(K,N-K)\right)\). Under the constraint \(\dbinom{N}{K}\leq 10^6\), we have \(\min(K,N-K)\leq 11\), so it will finish within the time limit.
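One way to see the bound of 11: writing \(m=\min(K,N-K)\), we have \(N\geq 2m\) and hence \(\dbinom{N}{K}=\dbinom{N}{m}\geq\dbinom{2m}{m}\). The following sketch searches for the largest \(m\) with \(\dbinom{2m}{m}\leq 10^6\); the names `LIMIT` and `m` are chosen here for illustration.

```python
import math

LIMIT = 10**6

# Largest m such that C(2m, m) <= LIMIT.
m = 0
while math.comb(2 * (m + 1), m + 1) <= LIMIT:
    m += 1

print(m)                            # 11
print(math.comb(2 * m, m))          # 705432
print(math.comb(2 * m + 2, m + 1))  # 2704156
```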
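import sys

sys.setrecursionlimit(3 * 10**5)

N, K = map(int, input().split())
A = list(map(int, input().split()))
ans = 0


def func(xor, idx, c):
    # xor: running XOR, idx: next index to consider, c: elements still to pick
    global ans
    if c == 0:
        ans = max(ans, xor)
        return
    if idx == N:
        return
    func(xor ^ A[idx], idx + 1, c - 1)  # pick A[idx]
    func(xor, idx + 1, c)               # skip A[idx]


if K <= N - K:
    # Enumerate the K chosen elements directly
    func(0, 0, K)
else:
    # Start from the XOR of all elements and enumerate the N - K unchosen ones
    all_xor = 0
    for i in range(N):
        all_xor ^= A[i]
    func(all_xor, 0, N - K)
print(ans)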
Python also provides itertools.combinations, a utility function that enumerates the length-\(K\) subsequences of a given length-\(N\) sequence. Using it, the same solution can be written as follows:
import itertools

N, K = map(int, input().split())
A = list(map(int, input().split()))
ans = 0
if K <= N - K:
    # Enumerate the K chosen elements directly
    for a in itertools.combinations(A, K):
        xor = 0
        for i in a:
            xor ^= i
        ans = max(ans, xor)
else:
    # Start from the XOR of all elements and enumerate the N - K unchosen ones
    all_xor = 0
    for i in range(N):
        all_xor ^= A[i]
    for a in itertools.combinations(A, N - K):
        xor = all_xor
        for i in a:
            xor ^= i
        ans = max(ans, xor)
print(ans)
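For local testing, it can help to wrap the same idea in a function and compare it against a direct enumeration on small inputs. The sketch below is one possible arrangement, not the editorial's own code; the function names `max_xor` and `brute` and the input values are made up for illustration.

```python
import itertools
from functools import reduce
from operator import xor


def max_xor(A, K):
    """Maximum XOR of K elements, enumerating min(K, N - K) elements."""
    N = len(A)
    if K <= N - K:
        base, size = 0, K
    else:
        # Start from the XOR of everything and remove the unchosen elements.
        base, size = reduce(xor, A, 0), N - K
    return max(reduce(xor, c, base) for c in itertools.combinations(A, size))


def brute(A, K):
    """Reference answer: always enumerates the K chosen elements."""
    return max(reduce(xor, c, 0) for c in itertools.combinations(A, K))


print(max_xor([3, 2, 6, 4], 2))  # 7
print(max_xor([3, 2, 6, 4], 3))  # 7
```

Both functions should agree for every valid \(K\); only `max_xor` stays fast when \(K\) is close to \(N\).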