Official

E - Xor Distances Editorial by penguinman


与えられる木を、適当な頂点 \(x\) を根とした根付き木に変換します。すると、\(\text{dist}(i,j)\) について以下のことが言えます。ここで頂点 \(k\) は頂点 \(i\), \(j\) の最小共通祖先であり、また \(a \oplus b\)\(a \ XOR \ b\) と等価です。

\[ \begin{aligned} \text{dist}(i,j)&=\text{dist}(k,i) \oplus \text{dist}(k,j) \end{aligned} \]

\(a \oplus a=0\) を利用して上記の式を変形します。

\[ \begin{aligned} \text{dist}(i,j)&=\text{dist}(k,i) \oplus \text{dist}(k,j)\\ &=\text{dist}(k,i) \oplus \text{dist}(k,j) \oplus \text{dist}(x,k) \oplus \text{dist}(x,k)\\ &=(\text{dist}(x,k) \oplus \text{dist}(k,i)) \oplus (\text{dist}(x,k) \oplus \text{dist}(k,j))\\ &=\text{dist}(x,i) \oplus \text{dist}(x,j) \end{aligned} \]

よって問題を以下のように言い換えることができます。

\(1 \leq i \lt j \leq N\) を満たす全ての組 \((i,j)\) について \(\text{dist}(x,i) \oplus \text{dist}(x,j)\) を求め、その総和を \((10^9+7)\) で割った余りを求めよ。

まず、前半の各頂点 \(i\) について \(\text{dist}(x,i)\) を求めるパートを考えます。これは BFS などを用いることで、\(O(N)\) で計算することができます。

次に、後半の \(\text{dist}(x,i) \oplus \text{dist}(x,j)\) の総和を求めるパートです。\(XOR\) はビット毎の演算なので、ビット毎に計算することを考えます。

\(XOR\) の定義より、\(\text{dist}(x,i) \oplus \text{dist}(x,j)\)\(k\) ビット目が \(1\) になる条件は、\(\text{dist}(x,i)\)\(k\) ビット目と \(\text{dist}(x,j)\)\(k\) ビット目が異なっていることです。逆にこれらが等しい場合、\(\text{dist}(x,i) \oplus \text{dist}(x,j)\)\(k\) ビット目は \(0\) になります。

よって \(1 \leq i \lt j \leq N\) を満たす全ての組 \((i,j)\) のうち、\(\text{dist}(x,i) \oplus \text{dist}(x,j)\)\(k\) ビット目が \(1\) になるものの個数は \((\text{dist}(x,i)\)\(k\) ビット目が \(0\) であるような頂点 \(i\) の個数\() \times (\text{dist}(x,i)\)\(k\) ビット目が \(1\) であるような頂点 \(i\) の個数\()\) に等しくなります。

故に \(k=1,2,\ldots,60\) について \((\text{dist}(x,i)\)\(k\) ビット目が \(0\) であるような頂点 \(i\) の個数\() \times (\text{dist}(x,i)\)\(k\) ビット目が \(1\) であるような頂点 \(i\) の個数\() \times 2^{k-1}\) を計算し、その総和を取ることで解を得ることができます。

後半部分の計算量・全体での計算量は共に \(O(N \log (\max(A)))\) となります。

解答例 (C++)

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const ll mod = 1e9+7;

int main(){
    int N; cin >> N;
    vector<vector<ll>> edge(N), weight(N);
    for(int i=1; i<N; i++){
        ll u,v,w; cin >> u >> v >> w;
        edge[--u].push_back(--v);
        edge[v].push_back(u);
        weight[u].push_back(w);
        weight[v].push_back(w);
    }
    vector<ll> dist(N,-1);
    std::queue<int> que;
    que.push(0);
    dist[0] = 0;
    while(!que.empty()){
        int now = que.front(); que.pop();
        for(int i=0; i<edge[now].size(); i++){
            int next = edge[now][i];
            ll sum = dist[now]^weight[now][i];
            if(dist[next] == -1){
                dist[next] = sum;
                que.push(next);
            }
        }
    }
    ll ans = 0;
    for(int i=0; i<60; i++){
        vector<int> cnt(2);
        for(int j=0; j<N; j++) cnt[dist[j]>>i&1]++;
        ans += (1ll<<i)%mod*cnt[0]%mod*cnt[1];
        ans %= mod;
    }
    cout << ans << endl;
}

解答例 (Python)

from collections import deque
N = int(input())
edge = [[]for i in range(N+1)]
weight = [[]for i in range(N+1)]
for i in range(1,N):
    u,v,w = map(int,input().split())
    edge[u].append(v)
    edge[v].append(u)
    weight[u].append(w)
    weight[v].append(w)
dist = [-1]*(N+1)
dist[1] = 0
que = deque([1])
while que:
    now = que.popleft()
    for i in range(len(edge[now])):
        nex = edge[now][i]
        if dist[nex] == -1:
            dist[nex] = dist[now]^weight[now][i]
            que.append(nex)
mod = int(1e9+7)
ans = 0
for i in range(60):
    cnt = [0]*2
    for j in range(N):
        cnt[dist[j+1]>>i&1] += 1
    ans += (1<<i)*cnt[0]*cnt[1]
    ans %= mod
print(ans)

posted:
last update: