E - Xor Distances Editorial by penguinman
与えられる木を、適当な頂点 \(x\) を根とした根付き木に変換します。すると、\(\text{dist}(i,j)\) について以下のことが言えます。ここで頂点 \(k\) は頂点 \(i\), \(j\) の最小共通祖先であり、また \(a \oplus b\) は \(a \ XOR \ b\) と等価です。
\[ \begin{aligned} \text{dist}(i,j)&=\text{dist}(k,i) \oplus \text{dist}(k,j) \end{aligned} \]
\(a \oplus a=0\) を利用して上記の式を変形します。
\[ \begin{aligned} \text{dist}(i,j)&=\text{dist}(k,i) \oplus \text{dist}(k,j)\\ &=\text{dist}(k,i) \oplus \text{dist}(k,j) \oplus \text{dist}(x,k) \oplus \text{dist}(x,k)\\ &=(\text{dist}(x,k) \oplus \text{dist}(k,i)) \oplus (\text{dist}(x,k) \oplus \text{dist}(k,j))\\ &=\text{dist}(x,i) \oplus \text{dist}(x,j) \end{aligned} \]
よって問題を以下のように言い換えることができます。
\(1 \leq i \lt j \leq N\) を満たす全ての組 \((i,j)\) について \(\text{dist}(x,i) \oplus \text{dist}(x,j)\) を求め、その総和を \((10^9+7)\) で割った余りを求めよ。
まず、前半の各頂点 \(i\) について \(\text{dist}(x,i)\) を求めるパートを考えます。これは BFS などを用いることで、\(O(N)\) で計算することができます。
次に、後半の \(\text{dist}(x,i) \oplus \text{dist}(x,j)\) の総和を求めるパートです。\(XOR\) はビット毎の演算なので、ビット毎に計算することを考えます。
\(XOR\) の定義より、\(\text{dist}(x,i) \oplus \text{dist}(x,j)\) の \(k\) ビット目が \(1\) になる条件は、\(\text{dist}(x,i)\) の \(k\) ビット目と \(\text{dist}(x,j)\) の \(k\) ビット目が異なっていることです。逆にこれらが等しい場合、\(\text{dist}(x,i) \oplus \text{dist}(x,j)\) の \(k\) ビット目は \(0\) になります。
よって \(1 \leq i \lt j \leq N\) を満たす全ての組 \((i,j)\) のうち、\(\text{dist}(x,i) \oplus \text{dist}(x,j)\) の \(k\) ビット目が \(1\) になるものの個数は \((\text{dist}(x,i)\) の \(k\) ビット目が \(0\) であるような頂点 \(i\) の個数\() \times (\text{dist}(x,i)\) の \(k\) ビット目が \(1\) であるような頂点 \(i\) の個数\()\) に等しくなります。
故に \(k=1,2,\ldots,60\) について \((\text{dist}(x,i)\) の \(k\) ビット目が \(0\) であるような頂点 \(i\) の個数\() \times (\text{dist}(x,i)\) の \(k\) ビット目が \(1\) であるような頂点 \(i\) の個数\() \times 2^{k-1}\) を計算し、その総和を取ることで解を得ることができます。
後半部分の計算量・全体での計算量は共に \(O(N \log (\max(A)))\) となります。
解答例 (C++)
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const ll mod = 1e9+7;
int main(){
int N; cin >> N;
vector<vector<ll>> edge(N), weight(N);
for(int i=1; i<N; i++){
ll u,v,w; cin >> u >> v >> w;
edge[--u].push_back(--v);
edge[v].push_back(u);
weight[u].push_back(w);
weight[v].push_back(w);
}
vector<ll> dist(N,-1);
std::queue<int> que;
que.push(0);
dist[0] = 0;
while(!que.empty()){
int now = que.front(); que.pop();
for(int i=0; i<edge[now].size(); i++){
int next = edge[now][i];
ll sum = dist[now]^weight[now][i];
if(dist[next] == -1){
dist[next] = sum;
que.push(next);
}
}
}
ll ans = 0;
for(int i=0; i<60; i++){
vector<int> cnt(2);
for(int j=0; j<N; j++) cnt[dist[j]>>i&1]++;
ans += (1ll<<i)%mod*cnt[0]%mod*cnt[1];
ans %= mod;
}
cout << ans << endl;
}
解答例 (Python)
from collections import deque
N = int(input())
edge = [[]for i in range(N+1)]
weight = [[]for i in range(N+1)]
for i in range(1,N):
u,v,w = map(int,input().split())
edge[u].append(v)
edge[v].append(u)
weight[u].append(w)
weight[v].append(w)
dist = [-1]*(N+1)
dist[1] = 0
que = deque([1])
while que:
now = que.popleft()
for i in range(len(edge[now])):
nex = edge[now][i]
if dist[nex] == -1:
dist[nex] = dist[now]^weight[now][i]
que.append(nex)
mod = int(1e9+7)
ans = 0
for i in range(60):
cnt = [0]*2
for j in range(N):
cnt[dist[j+1]>>i&1] += 1
ans += (1<<i)*cnt[0]*cnt[1]
ans %= mod
print(ans)
posted:
last update: