公式

H - Min of Restricted Sum 解説 by sounansya


まず、整数列 \(X,Y,Z\) から以下のようなグラフを作ります。

  • \(N\) 頂点 \(M\) 辺の無向グラフであって、 \(i=1,2,\ldots,M\) に対して頂点 \(X_i\) と頂点 \(Y_i\) を結ぶラベル \(Z_i\) がついた辺が存在する。

このグラフをもとに考えていきます。

まず、グラフの連結成分毎に問題は独立に解いて良いです。よって、以下ではグラフが連結である場合について考えます。

\(A_1\) の値を \(x\) に固定して考えてみます。すると、辺で繋がっている頂点は片方の値が決まるともう片方の値も決まるので、\(A_1\) の値を決めることで全ての \(A\) の値が決まります。この段階で矛盾が生じれば答えは \(-1\) です。

ここで、 \(A\) の各ビット毎に条件は独立なことに注目します。

全ての \(k\) に対し \(A_1\)\(k\) ビット目は \(0\)\(1\) で、どちらかに固定することで他の \(k\) ビット目も全て決まります。求める \(A\) は要素の総和が最小な良い整数列なので、 \(0\)\(1\) をそれぞれ試し、 \(k\) ビット目が \(1\) となる個数が少ない方(同じである場合はどちらでも良い)にすれば良いです。

これは BFS や DFS などで実装することができます。

以上を適切に実装することでこの問題を解くことができます。計算量は \(O((N+M)\log \max A)\) 時間です。

実装例 (Python3)

from collections import deque

n, m = map(int, input().split())
g = [[] for _ in range(n)]
for _ in range(m):
    x, y, z = map(int, input().split())
    x, y = x - 1, y - 1
    g[x].append((y, z))
    g[y].append((x, z))

visited = [False] * n
val = [-1] * n

def bfs(start):
    dq = deque([start])
    visited[start] = True
    comp = [start]
    while dq:
        v = dq.popleft()
        for u, w in g[v]:
            if not visited[u]:
                visited[u] = True
                val[u] = val[v] ^ w
                comp.append(u)
                dq.append(u)
            else:
                if val[u] != val[v] ^ w:
                    print("-1")
                    exit()
    return comp

ans = [0] * n
for st in range(n):
    if visited[st]:
        continue
    val[st] = 0
    comp = bfs(st)
    for i in range(30):
        cnt = 0
        for j in comp:
            if val[j] & (1 << i):
                cnt += 1
        if cnt < len(comp) - cnt:
            for j in comp:
                if val[j] & (1 << i):
                    ans[j] |= 1 << i
        else:
            for j in comp:
                if not (val[j] & (1 << i)):
                    ans[j] |= 1 << i
print(*ans)

実装例 (Python3)

import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')
import sys

sys.setrecursionlimit(10**7)
n, m = map(int, input().split())
g = [[] for i in range(n)]
for i in range(m):
    x, y, z = map(int, input().split())
    x, y = x - 1, y - 1
    g[x].append((y, z))
    g[y].append((x, z))
visited = [False] * n
val = [-1] * n


q = []


def dfs(v):
    visited[v] = True
    for u, w in g[v]:
        if not visited[u]:
            val[u] = val[v] ^ w
            q.append(u)
            dfs(u)
        else:
            if val[u] != val[v] ^ w:
                print("-1")
                exit()


ans = [0] * n
for st in range(n):
    if visited[st]:
        continue
    val[st] = 0
    q = [st]
    dfs(st)
    for i in range(30):
        cnt = 0
        for j in q:
            if val[j] & (1 << i):
                cnt += 1
        if cnt < len(q) - cnt:
            for j in q:
                if val[j] & (1 << i):
                    ans[j] |= 1 << i
        else:
            for j in q:
                if not val[j] & (1 << i):
                    ans[j] |= 1 << i
print(*ans)

投稿日時:
最終更新: