公式

G - GCD cost on the tree 解説 by mechanicalpenciI


根付き木として考え、木DPにより解くことを考えます。

また、以下では、木上の頂点 \(u,v\) について、
\(f(u,v)\)\(u=v\) ならば \(1\), \(u\neq v\)ならば \(u\)\(v\) の結ぶ単純パス上の端点を含む頂点数として定義します。
すなわち、\(2\) 頂点間の距離 \(d(u,v)\) に対してつねに \(f(u,v)=d(u,v)+1\) です。
また、\(g(u,v)\)\(u=v\) ならば \(A_u\), \(u\neq v\) ならば \(u\)\(v\) の結ぶ単純パス上おける \(\gcd(A_{p_1},\ldots,A_{p_k})\) の値として定義します。
ここで、相異なる \(2\) 頂点について、 \(C(u,v)=f(u,v)g(u,v)\) です。

最初、各頂点が独立している状態から、深さ優先探索の帰りがけ順にしたがって、根が \(r\) である木 \(T\) に根が \(r'\) である木 \(T'\) を「\(r\) の直接の子に\(r'\) を追加する」という操作でマージするという操作を繰り返すことで元の木を作ることができます。この操作は元の木の辺数 \(|T|-1\) 回だけ行われます。

この過程で出来る各木 \(T\) について、その木に対する答え \(ans(T)=\displaystyle\sum_{i=1}^{|T|-1}\sum_{j=i+1}^{|T|}C(i,j)\) を求め、更新していくことを考えます。 ここで、\(r\) を根とする根付き木 \(T\) について、\(cnt(T,x)\)\(T\) の頂点 \(v\) であって、\(g(r,v)=x\) であるようなものの個数、 \(sum(T,x)\)\(T\) の頂点 \(v\) であって、\(g(r,v)=x\) であるようなものについて、\(f(r,v)\) を足し合わせたものとします。

もし \(T_0\) ( 根は \(r_0\) ), \(T'\) ( 根は \(r'\) )についてそれぞれの値が求まっていたとすると、これをマージした木 \(T\) ( 根は \(r_0\) ) における答え \(ans(T)\) の値は

\[ ans(T)=ans(T_0)+ans(T')+\displaystyle\sum_{x}\sum_{y} \gcd(x,y)\times \left\{ cnt(T_0,x) sum(T',y)+ sum(T_0,x) cnt(T',y) \right\} \]

として求めることができます。第 \(1,2\) 項は \(ans(T_0)+ans(T')\) はそれぞれの木上の頂点同士について \(f(u,v)g(u,v)\) を足し合わせたものであり、第 \(3\) 項については、一方が \(T_0\), 他方が \(T'\) に属するような頂点対において、\(g(u,v)=\gcd(x,y)\) であるようなものについて \(f(u,v)\) の総和を足し合わせていると解釈することができます。
また、\(cnt(T,x)\)\(sum(T,x)\) は、
\(cnt(T,x)=cnt(T_0,x)+\displaystyle\sum_{\gcd(r,y)=x}cnt(T',y)\) および \(sum(T,x)=sum(T_0,x)+\displaystyle\sum_{\gcd(r,y)=x}(sum(T',y)+cnt(T',y))\)
で求めることができます。

なお, \(1\) 頂点 \(v\) のみからなる初期状態においては、\(cnt(T,x)=sum(T,x)=(x=A_v\) ならば \(1\) , それ以外の時 \(0)\) , \(ans(T)=0\) となります。

このようにして、マージ操作を繰り返すことで最終的に元の木における答えを求めることができます。 計算量を考えてみましょう。

\(cnt(T,x)=0\) であるような \(x\) については \(cnt(T,x), sum(T,x)\) の値を持つ必要はありません。そうでないような \(x\)\(A_r\) の約数の個数以下かつ \(|T|\) 個以下であるので、\(10^5\) 以下の正整数の約数の個数の最大値を \(D\) として、\(\min(|T|,D)\) で上から抑えられます。このとき、マージの際の計算回数は\(\min(|T_0|,D)\cdot \min(|T'|,D)\cdot \log(\max(A_i))+\min(|T_0|,D)+\min(|T'|,D)\) であり、\(C_0=\log(\max(A_i))+2\) として、\(C_0\min(|T_0|,D)\min(|T'|,D)\) で抑えることができます。

このことから、\(N\) 頂点の木に対する計算回数の最大値を \(h(N)\) で表すと、

\[ h(N)\leq \displaystyle\max_{1\leq i\leq N-1}(h(i)+h(N-i)+C_0\min(i,D)\min(N-i,D)) \]

が成り立ちます。 また、\(h(1)=1\leq\frac{1}{2}C_0\) と抑えることができます。 このとき、\(N\leq 2D\) の範囲で \(h(N)\leq \frac{1}{2}C_0N^2\) が成り立ち、さらに\(N> 2D\) の範囲で \(h(N)\leq \frac{1}{2}C_0( 3N-2D)D\) が成り立ちます。 よって、計算量は\(O(ND\log(\max(A_i)))\) となります。 \(N=10^5\) において \(D=128\) であるので、これは十分間に合います。よって、この問題を解くことができました。

c++による実装例 :

#include <bits/stdc++.h>
using namespace std;

#define N 100000
#define D 130
#define MOD 998244353
#define rep(i, n) for(int i = 0; i < n; ++i)
#define pil pair<int, long long>

vector<int>e[N];
bool used[N];
int a[N];
long long ans;

int gcd(int x, int y) {
	if (x > y)swap(x, y);
	if (x == 0)return y;
	return gcd(y%x, x);
}

vector<pair<int, pil> > dfs(int k) {
	used[k] = true;
	vector<pair<int, pil> > list1, list2;
	int idx[D];
	int g;
	long long x, y, z;
	list1.push_back({ a[k],{1,1LL} });
	int sz = e[k].size();
	int sz1, sz2, defalut_sz;
	rep(ii, sz) {
		if (!used[e[k][ii]]) {

			list2 = dfs(e[k][ii]);
			sz1 = list1.size();
			sz2 = list2.size();
			defalut_sz = sz1;

			rep(j, sz2) {
				idx[j] = -1;
				rep(i, sz1) {
					if (i < defalut_sz) {
						g = gcd(list1[i].first, list2[j].first);
						if (i == 0)list2[j].first = g;
						x = (list1[i].second.second*list2[j].second.first) % MOD;
						y = (list2[j].second.second*list1[i].second.first) % MOD;
						z = ((x + y)*g) % MOD;
						ans = (ans + z) % MOD;
					}					
					if (list1[i].first == list2[j].first)idx[j] = i;
				}
				if (idx[j] == -1) {
					idx[j] = sz1;
					list1.push_back({ list2[j].first,{0,0LL} });
					sz1++;
				}
			}

			rep(j, sz2) {
				list1[idx[j]].second.first += list2[j].second.first;
				list1[idx[j]].second.second += list2[j].second.first + list2[j].second.second;
				if (list1[idx[j]].second.second >= MOD)list1[idx[j]].second.second -= MOD;
			}

		}
	}
	return list1;
}

int main() {
	int n, x, y;
	cin >> n;
	rep(i, n)cin >> a[i];
	rep(i, n - 1) {
		cin >> x >> y;
		e[x - 1].push_back(y - 1);
		e[y - 1].push_back(x - 1);
	}
	ans = 0;
	rep(i, n)used[i] = false;
	dfs(0);
	cout << ans << endl;
	return 0;
}

投稿日時:
最終更新: