公式

E - 最も深い共通の上司 / Deepest Common Boss 解説 by kyopro_friends


この問題は LCA(lowest common ancestor) を求める問題そのものです。ダブリングを用いて解くことができます。

DFS/BFSなどにより、各社員 の深さを予め求めます。

観察

まず愚直な解法を考えると、以下の擬似コードのようなアルゴリズムを思いつきます。

func solve(x,y):
  while(depth[x] != depth[y]):
    if(depth[x] > depth[y]):
      x <- P[x]
    else:
      y <- P[y]
  while(x != y):
    x <- P[x]
    y <- P[y]
  return x

このアルゴリズムでは最悪計算量は \(\Theta(N)\) となります。(例えば、全ての \(i\)\(P_i= i-1\) のときに、 \((x,y)=(1,N)\) を聞く)

このアルゴリズムにおいて、1世代ずつ上へ遡っているのが無駄です。まとめて一気に遡ることを考えましょう。

解法

社員 \(i\)\(2^k\) 個上の上司を \(\mathrm{oya}[k][i]\) とします。\(2^k\) 個上の上司とは、\(2^{k-1}\) 個上の上司の \(2^{k-1}\) 個上の上司なので、

\(\mathrm{oya}[k][i]=\begin{cases} P_i & k = 0 のとき\\ \mathrm{oya}[k-1][\mathrm{oya}[k-1][i]] & k > 0 のとき \end{cases}\)

として、\(k\) の昇順にそれぞれ \(O(1)\) で求めることができるので、 \(k \leq \log_2 N\) の範囲全ての \(\mathrm{oya}[k][i]\)\(O(N\log N)\) 時間で求めることができます。

この配列を用いることで、 \(2^k\) 世代まとめて遡ることができるようになったため、 \(O(\log N)\) 時間でクエリに答えることができます。全体の計算量は \(O((N+Q)\log N)\) になります。

実装上は \(P_1=1\) と定めると例外処理の必要がなくなり楽です。

実装例 (C++)

#include<bits/stdc++.h>
using namespace std;

int main(){
  int n, q;
  cin >> n >> q;
  vector<int>p(n);
  for(int i=1; i<n; i++){
    cin >> p[i];
    p[i]--;
  }

  vector<vector<int>>oya(20, vector<int>(n));
  oya[0]=p;
  for(int k=1; k<20; k++){
    for(int i=0; i<n; i++){
      oya[k][i] = oya[k-1][oya[k-1][i]];
    }
  }

  vector<int>depth(n);
  for(int i=1; i<n; i++){
    depth[i] = depth[p[i]] + 1;
  }

  auto solve=[&](int x, int y){
    if(depth[x] > depth[y]){
      swap(x, y);
    }
    int diff = depth[y] - depth[x];
    for(int k=0; k<20; k++){
      if(diff & (1<<k)){
        y = oya[k][y];
      }
    }
    if(x == y){
      return x;
    }
    for(int k=19; k>=0; k--){
      if(oya[k][x] != oya[k][y]){
        x = oya[k][x];
        y = oya[k][y];
      }
    }
    return oya[0][x];
  };

  for(int i=0; i<q; i++){
    int x, y;
    cin >> x >> y;
    x--, y--;
    cout << solve(x, y) + 1 << endl;
  }
}

実装例 (Python)

N, Q = map(int, input().split())
P = list(map(int, input().split()))
P = [0] + [p-1 for p in P]
oya = [[-1]*N for _ in range(20)]
oya[0] = P
for k in range(1,20):
  for v in range(N):
    oya[k][v] = oya[k-1][oya[k-1][v]]

depth = [0] * N
for i in range(1, N):
  depth[i] = depth[P[i]] + 1

def solve(x, y):
  if depth[x] > depth[y]:
    x, y = y, x
  diff = depth[y] - depth[x]
  for k in range(20):
    if diff & (1<<k):
      y = oya[k][y]
  if x == y:
    return x
  for k in range(19, -1, -1):
    if oya[k][x] != oya[k][y]:
      x = oya[k][x]
      y = oya[k][y]
  return oya[0][x]

for _ in range(Q):
  x, y = map(int, input().split())
  x -= 1
  y -= 1
  print(solve(x, y) + 1)

投稿日時:
最終更新: