公式
H - Min of Restricted Sum 解説
by
H - Min of Restricted Sum 解説
by
sounansya
まず、整数列 \(X,Y,Z\) から以下のようなグラフを作ります。
- \(N\) 頂点 \(M\) 辺の無向グラフであって、 \(i=1,2,\ldots,M\) に対して頂点 \(X_i\) と頂点 \(Y_i\) を結ぶラベル \(Z_i\) がついた辺が存在する。
このグラフをもとに考えていきます。
まず、グラフの連結成分毎に問題は独立に解いて良いです。よって、以下ではグラフが連結である場合について考えます。
\(A_1\) の値を \(x\) に固定して考えてみます。すると、辺で繋がっている頂点は片方の値が決まるともう片方の値も決まるので、\(A_1\) の値を決めることで全ての \(A\) の値が決まります。この段階で矛盾が生じれば答えは \(-1\) です。
ここで、 \(A\) の各ビット毎に条件は独立なことに注目します。
全ての \(k\) に対し \(A_1\) の \(k\) ビット目は \(0\) か \(1\) で、どちらかに固定することで他の \(k\) ビット目も全て決まります。求める \(A\) は要素の総和が最小な良い整数列なので、 \(0\) と \(1\) をそれぞれ試し、 \(k\) ビット目が \(1\) となる個数が少ない方(同じである場合はどちらでも良い)にすれば良いです。
これは BFS や DFS などで実装することができます。
以上を適切に実装することでこの問題を解くことができます。計算量は \(O((N+M)\log \max A)\) 時間です。
from collections import deque
n, m = map(int, input().split())
g = [[] for _ in range(n)]
for _ in range(m):
x, y, z = map(int, input().split())
x, y = x - 1, y - 1
g[x].append((y, z))
g[y].append((x, z))
visited = [False] * n
val = [-1] * n
def bfs(start):
dq = deque([start])
visited[start] = True
comp = [start]
while dq:
v = dq.popleft()
for u, w in g[v]:
if not visited[u]:
visited[u] = True
val[u] = val[v] ^ w
comp.append(u)
dq.append(u)
else:
if val[u] != val[v] ^ w:
print("-1")
exit()
return comp
ans = [0] * n
for st in range(n):
if visited[st]:
continue
val[st] = 0
comp = bfs(st)
for i in range(30):
cnt = 0
for j in comp:
if val[j] & (1 << i):
cnt += 1
if cnt < len(comp) - cnt:
for j in comp:
if val[j] & (1 << i):
ans[j] |= 1 << i
else:
for j in comp:
if not (val[j] & (1 << i)):
ans[j] |= 1 << i
print(*ans)
import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')
import sys
sys.setrecursionlimit(10**7)
n, m = map(int, input().split())
g = [[] for i in range(n)]
for i in range(m):
x, y, z = map(int, input().split())
x, y = x - 1, y - 1
g[x].append((y, z))
g[y].append((x, z))
visited = [False] * n
val = [-1] * n
q = []
def dfs(v):
visited[v] = True
for u, w in g[v]:
if not visited[u]:
val[u] = val[v] ^ w
q.append(u)
dfs(u)
else:
if val[u] != val[v] ^ w:
print("-1")
exit()
ans = [0] * n
for st in range(n):
if visited[st]:
continue
val[st] = 0
q = [st]
dfs(st)
for i in range(30):
cnt = 0
for j in q:
if val[j] & (1 << i):
cnt += 1
if cnt < len(q) - cnt:
for j in q:
if val[j] & (1 << i):
ans[j] |= 1 << i
else:
for j in q:
if not val[j] & (1 << i):
ans[j] |= 1 << i
print(*ans)
投稿日時:
最終更新:
