E - チーム分けの整合性 / Consistency of Team Division 解説 by admin
gpt-5.5-high概要
各申告を「2 人のチーム番号の XOR が \(c\) である」という制約として扱います。
固定済みの申告だけで矛盾がないかを Union-Find で調べ、未確定 X の申告に対する良い確定方法の数を連結成分数から求めます。
考察
各人の所属チームを \(0, 1\) のどちらかで表すことにします。
すると、申告 \((u, v, c)\) は次の制約になります。
\[ team_u \oplus team_v = c \]
ここで \(\oplus\) は XOR です。
固定された申告だけを考える
まず、内容が 0 または 1 の申告だけを見ると、これは XOR 制約の集合です。
例えば、
- \(u\) と \(v\) は同じチーム: \(team_u \oplus team_v = 0\)
- \(u\) と \(v\) は異なるチーム: \(team_u \oplus team_v = 1\)
です。
このような制約は、重み付き Union-Find(Parity Union-Find)で矛盾判定できます。
すでに \(u\) と \(v\) の相対関係が決まっている場合、新しい制約と食い違えば矛盾です。
固定済み申告だけで矛盾しているなら、X をどう埋めても全体を満たすことはできないので答えは \(0\) です。
X の個数を単純に数えてはいけない
内容が X の申告が \(x\) 個あるからといって、答えが単純に \(2^x\) になるわけではありません。
例えば三角形の 3 辺がすべて X の場合、3 辺の値の XOR が \(0\) でなければなりません。
したがって \(2^3 = 8\) 通りすべてが良いわけではなく、良い確定方法は \(4\) 通りです。
つまり、X の申告同士にもサイクルを通じた制約が生じます。
良い確定方法の数
固定済み申告だけからなるグラフを考え、その連結成分数を \(C\) とします。
また、X も含めたすべての申告からなるグラフの連結成分数を \(D\) とします。
固定済み申告に矛盾がないとします。
固定済み申告の各連結成分では、内部の相対的なチーム関係は決まっています。
ただし、その連結成分全体をまとめて反転しても、固定済み申告はすべて満たされます。
つまり、固定済み申告の各連結成分ごとに「反転するかしないか」の自由度があります。
一方で、X の申告はこれらの連結成分同士をつなぎます。
すべての申告を含めた 1 つの連結成分内では、全体をまとめて反転しても X に入る値は変わりません。
したがって、すべての申告からなる 1 つの連結成分の中に、固定済み申告の連結成分が \(t\) 個あるとすると、良い確定方法の数は
\[ 2^{t-1} \]
通りです。
これを全連結成分について掛け合わせると、
\[ 2^{C-D} \]
になります。
よって、
- 固定済み申告だけで矛盾があるなら答えは \(0\)
- 矛盾がないなら答えは \(2^{C-D}\)
です。
アルゴリズム
各操作の後に現在の状態に対して答えを計算します。
制約は \(N, Q \leq 3000\) なので、毎回 Union-Find を作り直しても間に合います。
管理する情報
申告ごとに以下を配列で持ちます。
us[i]: 申告 \(i\) の一方の人vs[i]: 申告 \(i\) のもう一方の人cont[i]: 現在の内容012:X
また、すべての申告を無視せずに見たグラフの連結成分数 \(D\) は、通常の Union-Find で管理します。
R 操作では辺の端点は変わらないため、全申告グラフの連結成分数 \(D\) は変化しません。
A 操作で申告が追加されたときだけ、その 2 頂点を union します。
各操作の処理
A u v c
- 申告を配列に追加する
- すべての申告を含む Union-Find で \(u, v\) を union する
R k
申告 \(k\) の現在の内容を見て、申告 \(k+1\) を上書きします。
0なら11なら0XならX
答えの計算
各操作後、次を行います。
- Parity Union-Find を初期化する
- 現在の全申告を順に見る
- 内容が
Xの申告は無視する - 内容が
0または1の申告を XOR 制約として追加する - 矛盾があれば答えは \(0\)
- 矛盾がなければ、固定済み申告グラフの連結成分数を \(C\) として、答えは
\[ 2^{C-D} \]
です。
計算量
現在の申告数を \(M\) とすると、各操作後の再計算に
\[ O(N + M \alpha(N)) \]
かかります。
\(M \leq Q\) なので、全体では
\[ O(Q(N + Q)\alpha(N)) \]
です。
制約 \(N, Q \leq 3000\) では十分高速です。
- 時間計算量: \(O(Q(N + Q)\alpha(N))\)
- 空間計算量: \(O(N + Q)\)
実装のポイント
Parity Union-Find では、各頂点について「親との XOR」を持ちます。
find(x) では、
- 根
- \(x\) から根までの XOR
を返します。
制約
\[ team_u \oplus team_v = c \]
を追加するとき、すでに同じ連結成分にいるなら、現在分かっている XOR と \(c\) が一致するかを確認します。
違っていれば矛盾です。
別の連結成分なら、片方をもう片方に併合し、そのときに根同士の XOR 関係を設定します。
また、R 操作では「申告 \(k\) の現在の内容」を参照する点に注意が必要です。
元々の内容ではなく、これまでの R 操作によって更新された後の値を使います。
ソースコード
import sys
MOD = 998244353
def main():
input = sys.stdin.buffer.readline
N, Q = map(int, input().split())
pow2 = [1] * (N + Q + 5)
for i in range(1, len(pow2)):
pow2[i] = (pow2[i - 1] * 2) % MOD
us = []
vs = []
cont = [] # 0, 1, 2(X)
parent_all = list(range(N))
size_all = [1] * N
comp_all = N
def find_all(x):
while parent_all[x] != x:
parent_all[x] = parent_all[parent_all[x]]
x = parent_all[x]
return x
def union_all(a, b):
nonlocal comp_all
ra = find_all(a)
rb = find_all(b)
if ra == rb:
return
if size_all[ra] < size_all[rb]:
ra, rb = rb, ra
parent_all[rb] = ra
size_all[ra] += size_all[rb]
comp_all -= 1
def solve_current():
parent = list(range(N))
size = [1] * N
parity = [0] * N
comp = N
ok = True
def find(x):
r = x
acc = 0
while parent[r] != r:
acc ^= parity[r]
r = parent[r]
root = r
cur = x
val = acc
while parent[cur] != cur:
p = parent[cur]
d = parity[cur]
parent[cur] = root
parity[cur] = val
val ^= d
cur = p
return root, acc
nonlocal_us = us
nonlocal_vs = vs
nonlocal_cont = cont
for i, c in enumerate(nonlocal_cont):
if c == 2:
continue
a = nonlocal_us[i]
b = nonlocal_vs[i]
ra, xa = find(a)
rb, xb = find(b)
if ra == rb:
if (xa ^ xb) != c:
ok = False
else:
if size[ra] < size[rb]:
ra, rb = rb, ra
xa, xb = xb, xa
parent[rb] = ra
parity[rb] = xa ^ xb ^ c
size[ra] += size[rb]
comp -= 1
if not ok:
return 0
return pow2[comp - comp_all]
out = []
for _ in range(Q):
parts = input().split()
if parts[0] == b'A':
u = int(parts[1]) - 1
v = int(parts[2]) - 1
ch = parts[3]
if ch == b'X':
c = 2
else:
c = ch[0] - 48
us.append(u)
vs.append(v)
cont.append(c)
union_all(u, v)
else:
k = int(parts[1]) - 1
c = cont[k]
if c == 2:
cont[k + 1] = 2
else:
cont[k + 1] = c ^ 1
out.append(str(solve_current()))
sys.stdout.write("\n".join(out))
if __name__ == "__main__":
main()
この解説は gpt-5.5-high によって生成されました。
投稿日時:
最終更新: