E - Xor Distances Editorial
by
amitani
頂点\(x\)を根とした根付き木として考えます。公式解説では\(\mathrm{dist}(i, j)=\mathrm{dist}(x, i) \oplus \mathrm{dist}(x, j)\)という変形を利用していますが、これを用いなくても解くことができます。
ビットごとに分けて考えます。それぞれの辺に\(0\)か\(1\)を割り当てると考えて、\(1\)を奇数個含む最短パスの種類を求めれば良いです。右から\(k\)番目のビットについてこの種類数を\(f(k)\)として、最終的な答えは\(\Sigma_{k}f(k) \times 2^{k-1}\)となります。以下\(k\)は省略します。
木の各頂点\(i\)について、その頂点をlowest common ancestorとするようなパスの中で\(1\)を奇数個含むパスの種類\(f_i\)を求めていきます。(これらの和が\(f(k)\)になります)。これは\(i\)の子\(j\)について子を端点とし\(i\)を通らないパスの中で、\(1\)を偶数個通るものと奇数個通るものの個数がわかれば良いです。ただし、簡単のため長さ0のパスも一つ含むことにしましょう。葉は偶数個のパスが一つあることになります。これらをそれぞれ\(g_0(j), g_1(j)\)とし、\(i\)の子の集合を\(C_i\)とします。また、\(b_j\)を\(i\)から\(j\)への辺のビットが立っていれば\(1\), そうでなければ\(0\)とします。
\(i\)を端点とし根から遠ざかる方向のパスであって、異なる子を通るものの組み合わせを考えれば良いです。従って、
\( \begin{aligned} f_i = & \Sigma_{n, m \in C_i, b_n \oplus b_m = 0}(g_0(n) \times g_1(m) + g_1(n) \times g_0(m)) \\ &+ \Sigma_{n, m \in C_i, b_n \oplus b_m = 1}(g_0(n) \times g_0(m) + g_1(n) \times g_1(m)) \\ &+ \Sigma_{n \in C_i, b_n = 0}g_1(n) + \Sigma_{n \in C_i, b_n = 1}g_0(n) \end{aligned} \)
が成り立ちます。これは累積和を用いることで子の数についての線形時間で求めることができます。(組み合わせをすべて調べてしまうと二乗となるので、例えばスターグラフで間に合わなくなってしまいます。) 累積和を求める際に\(i\)を端点とする長さ\(0\)のパスを考えると多少実装が楽になります。
さらに\(g_0(i), g_1(i)\)の漸化式として
\( \begin{aligned} g_0(i) &= 1 + \Sigma_{n \in C_i} g_{1-b_n}(n) \\ g_1(i) &= \Sigma_{n \in C_i} g_{b_n}(n) \end{aligned} \)
が成立します。これらを用いて葉からdpをするか、再帰を用いて求めていくことができます。
注意点としては、\(\mathrm{dist}(x, i)\)を使う手法と異なり、それぞれのbitについて個数を保存する方があるので配列とする必要があります。そのため定数倍遅いはずですが、実行時間制限が3秒なこともあり、PyPy3でも間に合いました。
以下はPythonのACコード例です。
import sys
input = sys.stdin.readline
MOD = 1000000007
M = 64
ones = [1] * M
bits = [1 << m for m in range(M)]
def solve(N: int, U: "List[int]", V: "List[int]", W: "List[int]"):
edges = [[] for _ in range(N+1)]
for u, v, w in zip(U, V, W):
u -= 1
v -= 1
edges[u].append((v, w))
edges[v].append((u, w))
counts = [[1] * (N*M), [0] * (N*M)] # 長さ0のパス1個分は初めから入れておく
ans = 0
to_visit = [(-1, 0, 0)] # parent, current, visited
while to_visit:
parent, current, visited = to_visit.pop()
if not visited:
# Process forward visit
to_visit.append((parent, current, 1))
for v, w in edges[current]:
if v == parent:
continue
to_visit.append((current, v, 0))
continue
# Process return visit
for v, w in edges[current]:
if v == parent:
continue
for m, bit in enumerate(bits):
is_odd_edge = 1 if w & bit else 0
is_even_edge = 1 - is_odd_edge
ans += counts[0][current*M+m] * counts[is_even_edge][v*M+m] * bit
ans %= MOD
ans += counts[1][current*M+m] * counts[is_odd_edge][v*M+m] * bit
ans %= MOD
counts[0][current*M+m] += counts[is_odd_edge][v*M+m]
counts[1][current*M+m] += counts[is_even_edge][v*M+m]
print(ans)
return
def main():
N = int(input()[:-1])
U = [0] * (N-1)
V = [0] * (N-1)
W = [0] * (N-1)
for i in range(N-1):
U[i], V[i], W[i] = map(int, input().split())
solve(N, U, V, W)
if __name__ == '__main__':
main()
posted:
last update: