C - Σ Editorial by toam

Python で set を用いる場合の注意点

はじめに

この問題のコードを例えば以下のように書くと,用意されたテストケースに対しては制限時間内に正しい答えを求めることができます.しかし,厳密にはこのコードが適切であるとは言えません.

N, K = map(int, input().split())
A = list(map(int, input().split()))
S = set(A)

ans = K*(K+1)//2
for i in S:
    if i <= K:
        ans -= i

print(ans)

本解説では,Python で set および dict を用いる際に注意すべき点について述べます.

コードが適切ではない理由

Python で実装されている set や dict の計算量は expected \(O(1)\) で,基本的には高速で動作します.しかし,set や dict で実装されている中身を悪用することで,Python に対して意図的ないじわるなケースを作れてしまう (Hack できてしまう) ことが知られています.競プロにおいて,制約を満たしたいかなるケースに対しても正しいコードを書くことが求められるという観点からすると,上のコードは厳密には適切とは言えません.今回は Hack ケースが用意されていませんでしたが,Hack を目的とした意図的なテストケースが用意されていた場合は上のコードでは AC が得られなくなってしまいます.

海外のコンテストサイト Codeforces ではコンテスト中に他人のコードを Hack することができます.そのため,set や dict をそのまま用いるとしばしば Hack されることがあるため特に注意が必要です.

Hack ケース

上のコードで TLE してしまう入力例を紹介します.

Python (CPython 3.11.4) に対する Hack ケース

def generate_python_killer():
    mask = (1 << 30)-1
    A = []
    for i in range(10):
        A.append(mask+2+i)
    x = 6
    for i in range(10**4):
        for j in range(10):
            A.append(x+j)
        x = 5*x+1
        x &= mask
    A += [1]*(N-len(A))
    return A


# N, K = map(int, input().split())
# A = list(map(int, input().split()))
N, K = 2*10**5, 2*10**9
A = generate_python_killer()
S = set(A)
ans = K*(K+1)//2
for i in S:
    if i <= K:
        ans -= i
print(ans)

Python (PyPy 3.10-v7.3.12) に対する Hack ケース

def generate_pypy_killer():
    mask = (1 << 30)-1
    A = [mask+2]
    x = 6
    for i in range(43599):
        A.append(x)
        x = (x*5+1) & mask
    A += [1]*(N-len(A))
    return A


# N, K = map(int, input().split())
# A = list(map(int, input().split()))
N, K = 2*10**5, 2*10**9
A = generate_pypy_killer()
S = set(A)
ans = K*(K+1)//2
for i in S:
    if i <= K:
        ans -= i
print(ans)

コードテストに上の Hack ケースを貼って,実際に TLE することを体感してみると良いかもしれません.

Hack ケースを作るにあたって以下の記事を参考にしました.

Hack ケースを防ぐ方法

競プロでは,意図的な入力による Hack を防ぐ方法として乱数を用いる方法が主流だと筆者は感じます.適当な乱数 \(R\) を用意し,値 \(x\) の代わりに \(R\oplus x\) を用いるようにすることで,いかなる(意図的な)入力に対しても Hack されることを高確率で防ぐことができます.

import random
R = random.randint(1, 1 << 60)

N, K = map(int, input().split())
A = list(map(int, input().split()))
S = set([i ^ R for i in A]) # hack を防ぐため,i の代わりに i ^ R を用いる
ans = K*(K+1)//2
for j in S:
    i = j ^ R  # 元に戻す
    if i <= K:
        ans -= i
print(ans)

他にも,計算量は落ちますが hash を用いずに集合を管理するデータ構造を使うことによっても回避することが可能です.

posted:
last update: