E - 最も深い共通の上司 / Deepest Common Boss 解説
by
kyopro_friends
この問題は LCA(lowest common ancestor) を求める問題そのものです。ダブリングを用いて解くことができます。
DFS/BFSなどにより、各社員 の深さを予め求めます。
観察
まず愚直な解法を考えると、以下の擬似コードのようなアルゴリズムを思いつきます。
func solve(x,y):
while(depth[x] != depth[y]):
if(depth[x] > depth[y]):
x <- P[x]
else:
y <- P[y]
while(x != y):
x <- P[x]
y <- P[y]
return x
このアルゴリズムでは最悪計算量は \(\Theta(N)\) となります。(例えば、全ての \(i\) で \(P_i= i-1\) のときに、 \((x,y)=(1,N)\) を聞く)
このアルゴリズムにおいて、1世代ずつ上へ遡っているのが無駄です。まとめて一気に遡ることを考えましょう。
解法
社員 \(i\) の \(2^k\) 個上の上司を \(\mathrm{oya}[k][i]\) とします。\(2^k\) 個上の上司とは、\(2^{k-1}\) 個上の上司の \(2^{k-1}\) 個上の上司なので、
\(\mathrm{oya}[k][i]=\begin{cases} P_i & k = 0 のとき\\ \mathrm{oya}[k-1][\mathrm{oya}[k-1][i]] & k > 0 のとき \end{cases}\)
として、\(k\) の昇順にそれぞれ \(O(1)\) で求めることができるので、 \(k \leq \log_2 N\) の範囲全ての \(\mathrm{oya}[k][i]\) を \(O(N\log N)\) 時間で求めることができます。
この配列を用いることで、 \(2^k\) 世代まとめて遡ることができるようになったため、 \(O(\log N)\) 時間でクエリに答えることができます。全体の計算量は \(O((N+Q)\log N)\) になります。
実装上は \(P_1=1\) と定めると例外処理の必要がなくなり楽です。
実装例 (C++)
#include<bits/stdc++.h>
using namespace std;
int main(){
int n, q;
cin >> n >> q;
vector<int>p(n);
for(int i=1; i<n; i++){
cin >> p[i];
p[i]--;
}
vector<vector<int>>oya(20, vector<int>(n));
oya[0]=p;
for(int k=1; k<20; k++){
for(int i=0; i<n; i++){
oya[k][i] = oya[k-1][oya[k-1][i]];
}
}
vector<int>depth(n);
for(int i=1; i<n; i++){
depth[i] = depth[p[i]] + 1;
}
auto solve=[&](int x, int y){
if(depth[x] > depth[y]){
swap(x, y);
}
int diff = depth[y] - depth[x];
for(int k=0; k<20; k++){
if(diff & (1<<k)){
y = oya[k][y];
}
}
if(x == y){
return x;
}
for(int k=19; k>=0; k--){
if(oya[k][x] != oya[k][y]){
x = oya[k][x];
y = oya[k][y];
}
}
return oya[0][x];
};
for(int i=0; i<q; i++){
int x, y;
cin >> x >> y;
x--, y--;
cout << solve(x, y) + 1 << endl;
}
}
実装例 (Python)
N, Q = map(int, input().split())
P = list(map(int, input().split()))
P = [0] + [p-1 for p in P]
oya = [[-1]*N for _ in range(20)]
oya[0] = P
for k in range(1,20):
for v in range(N):
oya[k][v] = oya[k-1][oya[k-1][v]]
depth = [0] * N
for i in range(1, N):
depth[i] = depth[P[i]] + 1
def solve(x, y):
if depth[x] > depth[y]:
x, y = y, x
diff = depth[y] - depth[x]
for k in range(20):
if diff & (1<<k):
y = oya[k][y]
if x == y:
return x
for k in range(19, -1, -1):
if oya[k][x] != oya[k][y]:
x = oya[k][x]
y = oya[k][y]
return oya[0][x]
for _ in range(Q):
x, y = map(int, input().split())
x -= 1
y -= 1
print(solve(x, y) + 1)
投稿日時:
最終更新:
