Official

F - Compare Tree Weights Editorial by mechanicalpenciI


頂点 \(1\) を根とした根付き木として考えます。

ここで、頂点 \(1\) から始めて、深さ優先探索の行きがけ順に頂点にあらたに番号 \(f(v)\) を割り当てることを考えます。このとき任意の頂点 \(v\) について次が成り立ちます。

頂点 \(v\) を根とした部分木の各頂点に割り当てられた番号は連続している。すなわち、\(1\leq l\leq r\leq N\) をみたす整数の組 \((l,r)\) が存在し、「頂点に割り当てられた番号が \(l\) 以上 \(r\) 以下である」事と「頂点が部分木に属する」ことが同値となる。

これは、\(v=1\) のときは部分木が \(T\) 自身になるため明らかです。そうでない場合、深さ優先探索で頂点を訪ねる順番を考えると、頂点 \(1\) から始めた探索は、あるタイミングで頂点 \(v\) の親から頂点 \(v\) へ移動し、頂点 \(v\) およびその子孫をすべて訪ねた後に頂点 \(v\) の親へ戻り、二度と頂点 \(v\) を(すなわちその子孫も)訪ねることはないことから従います。

各頂点の重みをその頂点に割り当てられた番号を添字とした配列 \(A=(A_1,A_2,\ldots,A_N)\)上で管理することを考えるとクエリは次のように言い換えられます。

  • 1 x w : \(A_{f(x)}\)\(x\) 増加させる。
  • 2 y : 辺 \(y\) で結ばれている \(2\) 頂点のうち子の方の頂点を \(v\) として、\(v\) を根とした部分木に対応する添字の区間 \([l.r]\) について \(\left\lvert\left(\displaystyle\sum_{i=1}^{l-1}A_i+\displaystyle\sum_{i=r+1}^{N}A_i\right)-\displaystyle\sum_{i=l}^{r}A_i\right\rvert =\left\lvert\displaystyle\sum_{i=1}^{N}A_i-2\displaystyle\sum_{i=l}^{r}A_i\right\rvert\) を求める。

よって、すべての頂点の重みの総和\(\displaystyle\sum_{i=1}^{N}A_i\) はクエリごとに更新すれば良いことを考えると、一点加算および区間和の取得が高速に行えれば良いことがわかります。これは Fenwick tree などのデータ構造を用いてそれぞれ \(O(\log N)\) で行うことができます。
なお、各頂点を根とした部分木に対応する区間は、その頂点に割り当てられた番号が \(l\) であり、その頂点から親へ戻る時に次につける予定の番号から \(1\) を引いたものが \(r\) となるため、一度の深さ優先探索ですべての頂点について求めることができます。
よって、最初の深さ優先探索に \(O(N)\), その後のクエリにそれぞれ \(O(\log N)\) かかるため、全体で \(O(N+Q\log N)\) となり十分高速です。
よって、この問題を解くことができました。

c++ による実装例:

#include <bits/stdc++.h>
#include <atcoder/fenwicktree>

using namespace std;
using namespace atcoder;

#define N 300010

int cur=0;
int l_idx[N];
int r_idx[N];
vector<int>e[N];

void dfs(int k){
	l_idx[k]=cur;
	cur++;
	int sz=e[k].size();
	for(int i=0;i<sz;i++)if(l_idx[e[k][i]]==-1)dfs(e[k][i]);
	r_idx[k]=cur;
	return;
}

int main() {
	int n;
	int u[N];
	int v[N];

	cin>>n;
	for(int i=0;i<n-1;i++){
		cin>>u[i]>>v[i];
		u[i]--,v[i]--;
		e[u[i]].push_back(v[i]);
		e[v[i]].push_back(u[i]);
	}

	for(int i=0;i<n;i++)l_idx[i]=-1,r_idx[i]=-1;
	dfs(0);

	int q;
	int t,x,w,y,z;

	fenwick_tree<int> fw(n);
	for(int i=0;i<n;i++)fw.add(i,1);
	int s=n;

	cin>>q;
	for(int i=0;i<q;i++){
		cin>>t>>x;
		x--;
		if(t==1){
			cin>>w;
			fw.add(l_idx[x],w);
			s+=w;
		}
		else{
			if(l_idx[u[x]]<l_idx[v[x]])y=v[x];
			else y=u[x];
			cout << abs(fw.sum(l_idx[y],r_idx[y])*2-s) <<endl;
		}
	}
	return 0;
}

posted:
last update: