Official

E - Xor Distances Editorial by en_translator


Root the given tree at an arbitrary vertex \(x\). For vertices \(i\) and \(j\), let \(k\) be their lowest common ancestor, and let \(a \oplus b\) denote the bitwise XOR of \(a\) and \(b\). Since the path from \(i\) to \(j\) consists of the path from \(i\) to \(k\) followed by the path from \(k\) to \(j\), \(\text{dist}(i,j)\) satisfies:

\[ \begin{aligned} \text{dist}(i,j)&=\text{dist}(k,i) \oplus \text{dist}(k,j) \end{aligned} \]

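Since \(a \oplus a=0\), we may XOR in \(\text{dist}(x,k)\) twice without changing the value. Moreover, \(k\) is an ancestor of both \(i\) and \(j\), so it lies on the path from \(x\) to each of them, which gives \(\text{dist}(x,k) \oplus \text{dist}(k,i) = \text{dist}(x,i)\) and likewise for \(j\). Hence: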

\[ \begin{aligned} \text{dist}(i,j)&=\text{dist}(k,i) \oplus \text{dist}(k,j)\\ &=\text{dist}(k,i) \oplus \text{dist}(k,j) \oplus \text{dist}(x,k) \oplus \text{dist}(x,k)\\ &=(\text{dist}(x,k) \oplus \text{dist}(k,i)) \oplus (\text{dist}(x,k) \oplus \text{dist}(k,j))\\ &=\text{dist}(x,i) \oplus \text{dist}(x,j) \end{aligned} \]
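As a sanity check of this identity, the following sketch compares the XOR along the actual path between every pair of vertices with \(\text{dist}(x,i) \oplus \text{dist}(x,j)\), for every choice of root \(x\), on a small tree. The edge list, the 0-indexed vertices, and the helper name `xor_from` are hypothetical and chosen only for illustration.

```python
from itertools import combinations

def xor_from(src, n, adj):
    """Return d where d[v] is the XOR of edge weights on the unique path src -> v."""
    d = [-1] * n          # -1 marks "not visited yet" (weights are non-negative)
    d[src] = 0
    stack = [src]
    while stack:
        u = stack.pop()
        for v, w in adj[u]:
            if d[v] == -1:
                d[v] = d[u] ^ w
                stack.append(v)
    return d

# Hypothetical sample tree given as (u, v, w) with 0-indexed vertices.
edges = [(0, 1, 5), (0, 2, 3), (2, 3, 6), (2, 4, 1)]
n = 5
adj = [[] for _ in range(n)]
for u, v, w in edges:
    adj[u].append((v, w))
    adj[v].append((u, w))

for root in range(n):
    from_root = xor_from(root, n, adj)
    for i, j in combinations(range(n), 2):
        # xor_from(i, ...)[j] is the XOR along the i-j path, i.e. dist(i, j).
        assert xor_from(i, n, adj)[j] == from_root[i] ^ from_root[j]
print("identity holds for every root and every pair")
```

The check is quadratic (or worse) and only meant for tiny trees.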

Therefore, the problem can be rephrased as follows:

Find the sum of \(\text{dist}(x,i) \oplus \text{dist}(x,j)\) over all pairs \((i,j)\) with \(1 \leq i \lt j \leq N\), modulo \(10^9+7\).
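A direct \(O(N^2)\) evaluation of this rephrased sum is too slow for large \(N\), but it can serve as a reference when testing a faster method. In this sketch the list `d` is assumed to hold \(\text{dist}(x,i)\) for every vertex; the function name `brute_force_sum` and the sample values are made up for illustration.

```python
MOD = 10**9 + 7

def brute_force_sum(d):
    """O(N^2) reference: sum of d[i] ^ d[j] over all pairs i < j, modulo MOD."""
    total = 0
    n = len(d)
    for i in range(n):
        for j in range(i + 1, n):
            total += d[i] ^ d[j]
    return total % MOD

print(brute_force_sum([0, 5, 3, 5, 2]))  # 42
```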

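First, \(\text{dist}(x,i)\) for every vertex \(i\) can be computed in \(O(N)\) time by a BFS (breadth-first search) from \(x\): when an edge of weight \(w\) leads from an already visited vertex \(u\) to an unvisited vertex \(v\), set \(\text{dist}(x,v) = \text{dist}(x,u) \oplus w\).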
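The BFS step can be sketched as a standalone function. The input format (0-indexed vertices, edges as `(u, v, w)` tuples) and the name `xor_dist_bfs` are assumptions for this sketch, not part of the problem's input format.

```python
from collections import deque

def xor_dist_bfs(n, edges, root=0):
    """BFS from root; returns d with d[v] = XOR of the weights on the root-v path.

    Assumes 0-indexed vertices and edges given as (u, v, w) tuples.
    """
    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    d = [-1] * n          # -1 marks "not visited yet"
    d[root] = 0
    que = deque([root])
    while que:
        now = que.popleft()
        for nxt, w in adj[now]:
            if d[nxt] == -1:
                d[nxt] = d[now] ^ w
                que.append(nxt)
    return d

print(xor_dist_bfs(5, [(0, 1, 5), (0, 2, 3), (2, 3, 6), (2, 4, 1)]))  # [0, 5, 3, 5, 2]
```

Each vertex is enqueued once and each edge is scanned twice, which gives the \(O(N)\) bound.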

Next, we want to find the sum of \(\text{dist}(x,i) \oplus \text{dist}(x,j)\). Since XOR is a bitwise operation, we compute this sum bit by bit.

By the definition of XOR, the \(k\)-th bit of \(\text{dist}(x,i) \oplus \text{dist}(x,j)\) is \(1\) if the \(k\)-th bits of \(\text{dist}(x,i)\) and \(\text{dist}(x,j)\) are different, and \(0\) if they are the same.

Therefore, among all pairs \((i,j)\) with \(1 \leq i \lt j \leq N\), the number of pairs for which the \(k\)-th bit of \(\text{dist}(x,i) \oplus \text{dist}(x,j)\) is \(1\) equals (the number of vertices \(i\) such that the \(k\)-th bit of \(\text{dist}(x,i)\) is \(0\)) \(\times\) (the number of vertices \(i\) such that the \(k\)-th bit of \(\text{dist}(x,i)\) is \(1\)).
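The counting argument can be checked on a small made-up list of distances: for each bit \(k\), the product of the two counts should match a direct count over all pairs. The function name `pairs_with_bit_set` and the sample values are hypothetical.

```python
from itertools import combinations

def pairs_with_bit_set(d, k):
    """Count pairs i < j whose k-th bit of d[i] ^ d[j] is 1, as c0 * c1."""
    c1 = sum((v >> k) & 1 for v in d)   # values whose k-th bit is 1
    c0 = len(d) - c1                    # values whose k-th bit is 0
    return c0 * c1

d = [0, 5, 3, 5, 2]  # made-up sample values of dist(x, i)
for k in range(3):
    brute = sum(((a ^ b) >> k) & 1 for a, b in combinations(d, 2))
    assert pairs_with_bit_set(d, k) == brute
```

Only a pair made of one "0" vertex and one "1" vertex produces a set bit, which is why the product of the two counts suffices.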

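Hence, let \(c_{k,0}\) and \(c_{k,1}\) be the numbers of vertices \(i\) such that the \(k\)-th bit of \(\text{dist}(x,i)\) is \(0\) and \(1\), respectively. Each pair counted by \(c_{k,0} \cdot c_{k,1}\) contributes \(2^k\) to the sum, so the answer is

\[ \left( \sum_{k=0}^{59} 2^k \cdot c_{k,0} \cdot c_{k,1} \right) \bmod (10^9+7). \]

Here bits are numbered from \(0\) (the least significant bit), and, as in the sample code, every weight is assumed to be less than \(2^{60}\), so 60 bits suffice.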
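The whole formula fits in a few lines. This sketch takes the list of \(\text{dist}(x,i)\) values as input and, like the sample code, assumes every value is below \(2^{60}\); the name `sum_pairwise_xor` and the `bits` parameter are hypothetical.

```python
MOD = 10**9 + 7

def sum_pairwise_xor(d, bits=60):
    """Sum of d[i] ^ d[j] over i < j, modulo MOD, computed bit by bit in O(N * bits)."""
    ans = 0
    for k in range(bits):
        c1 = sum((v >> k) & 1 for v in d)   # vertices whose k-th bit is 1
        c0 = len(d) - c1                    # vertices whose k-th bit is 0
        ans = (ans + (1 << k) % MOD * c0 % MOD * c1) % MOD
    return ans

print(sum_pairwise_xor([0, 5, 3, 5, 2]))  # 42
```

Python integers do not overflow, but the reductions modulo \(10^9+7\) mirror what the C++ sample code must do with 64-bit integers.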

The time complexity of the latter part, and therefore of the whole algorithm, is \(O(N \log (\max w))\), where \(w\) ranges over the edge weights; the memory usage is \(O(N)\).

Sample Code (C++)

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const ll mod = 1e9+7;

int main(){
    int N; cin >> N;
    vector<vector<ll>> edge(N), weight(N);
    for(int i=1; i<N; i++){
        ll u,v,w; cin >> u >> v >> w;
        edge[--u].push_back(--v);
        edge[v].push_back(u);
        weight[u].push_back(w);
        weight[v].push_back(w);
    }
    vector<ll> dist(N,-1);
    std::queue<int> que;
    que.push(0);
    dist[0] = 0;
    while(!que.empty()){
        int now = que.front(); que.pop();
        for(int i=0; i<edge[now].size(); i++){
            int next = edge[now][i];
            ll sum = dist[now]^weight[now][i];
            if(dist[next] == -1){
                dist[next] = sum;
                que.push(next);
            }
        }
    }
    ll ans = 0;
    for(int i=0; i<60; i++){
        vector<int> cnt(2);
        for(int j=0; j<N; j++) cnt[dist[j]>>i&1]++;
        ans += (1ll<<i)%mod*cnt[0]%mod*cnt[1];
        ans %= mod;
    }
    cout << ans << endl;
}

Sample Code (Python)

from collections import deque
N = int(input())
edge = [[]for i in range(N+1)]
weight = [[]for i in range(N+1)]
for i in range(1,N):
    u,v,w = map(int,input().split())
    edge[u].append(v)
    edge[v].append(u)
    weight[u].append(w)
    weight[v].append(w)
dist = [-1]*(N+1)
dist[1] = 0
que = deque([1])
while que:
    now = que.popleft()
    for i in range(len(edge[now])):
        nex = edge[now][i]
        if dist[nex] == -1:
            dist[nex] = dist[now]^weight[now][i]
            que.append(nex)
mod = int(1e9+7)
ans = 0
for i in range(60):
    cnt = [0]*2
    for j in range(N):
        cnt[dist[j+1]>>i&1] += 1
    ans += (1<<i)*cnt[0]*cnt[1]
    ans %= mod
print(ans)
