公式

H - Maximize XOR 解説 by toam


特殊な制約 \(\dbinom{N}{K}\leq 10^6\) に注目します.\(A\) から異なる \(K\) 個の項を選ぶ方法が \(10^6\) 通り以下であることが保証されているので,探索方法を工夫することでそのような選び方を全探索できます.

全探索をする方法として,以下のようなコードを書くことができそうです.

def func(x: list[int], i: int):
    if len(x) == K:
        # 長さが K であるようなインデックスの列 x がここで列挙される
        return
    if i == N:
        return
    func(x, i + 1)
    func(x + [i], i + 1)

func([], 0)

ただし,この方法では \(K\) が大きいとき実行時間に間に合うことができません.なぜなら,関数 func が呼び出される回数が \(\displaystyle \sum_{i=0}^K \dbinom{N}{i}\) 回程度となり,\(\dbinom{N}{K}\) の値は小さくても \(\dbinom{N}{\lfloor N/2\rfloor}\) の値が大きくなってしまうことがあるからです.

例えば,\((N,K)=(100,98)\) のとき,\(\dbinom{N}{K}=4950\lt 10^6\) で制約を満たしますが,\(\dbinom{100}{50}\) は約 \(10^{29}\) と非常に大きくなってしまいます.

これを回避するためには,\(K\) が大きいときには選ばれる \(K\) 個を考える代わりに選ばれない \(N-K\) 個を考えることで 実行時間に間に合わせることができます.具体的には,以下のような場合分けによって答えを計算することができます.

  • \(K\leq N-K\) のとき:上のコードのように愚直に全探索をすればよいです.
  • \(K\gt N-K\) のとき:あらかじめすべての要素の XOR を計算しておき,選ばない \(N-K\) 個を全探索することで \(K\) 個選んだときの総 XOR を高速に求めることができます.

計算量は \(O\left(\dbinom{N}{K}\min(K,N-K)\right )\) です.\(\dbinom{N}{K}\leq 10^6\) という制約では \(\min(K,N-K)\leq 11\) が成り立つのでこれで間に合います.

import sys

sys.setrecursionlimit(3 * 10**5)

N, K = map(int, input().split())
A = list(map(int, input().split()))

ans = 0


def func(xor, idx, c):
    global ans
    if c == 0:
        ans = max(ans, xor)
        return
    if idx == N:
        return
    func(xor ^ A[idx], idx + 1, c - 1)
    func(xor, idx + 1, c)


if K <= N - K:
    func(0, 0, K)
else:
    all_xor = 0
    for i in range(N):
        all_xor ^= A[i]
    func(all_xor, 0, N - K)

print(ans)

Python であれば,長さ \(N\) の列 \(A\) のうち長さが \(K\) である部分列の列挙が itertools.combinations を用いることで容易にできます.

import itertools

N, K = map(int, input().split())
A = list(map(int, input().split()))

ans = 0

if K <= N - K:
    for a in itertools.combinations(A, K):
        xor = 0
        for i in a:
            xor ^= i
        ans = max(ans, xor)
else:
    all_xor = 0
    for i in range(N):
        all_xor ^= A[i]
    for a in itertools.combinations(A, N - K):
        xor = all_xor
        for i in a:
            xor ^= i
        ans = max(ans, xor)

print(ans)

投稿日時:
最終更新: