Official

F - Diameter set Editorial by mechanicalpenciI


木を \(T\) とし、\(d(u,v)\)\(T\) 上での頂点 \(u\) , \(v\) 間の距離を表すとします。

適当な頂点を選んで、その点から最も遠い頂点のうちの一つを \(X\) とします。さらに、\(X\) から最も遠い頂点の一つを選び、 \(Y\) とします。このとき、\(d(X,Y)\) は直径に等しい事が知られています。この長さを \(D\) とします。また、\(X\)\(Y\)の間のパスを考え、端点を含むパス上の頂点を順に、 \(V_0(=X)\) , \(V_1\) , \(\ldots\) , \(V_D(=Y)\) とします。

  • \(D\) が奇数のとき

頂点 \(A\) , \(B\)\(A=V_{\frac{D-1}{2}}\) , \(B=V_{\frac{D+1}{2}}\) とします。\(A\)\(B\) を結ぶ辺を切って、\(T\)\(2\) つの部分木 \(T'\) , \(T"\) に分けることを考えます。
このとき、 \(T'\) に属する任意の頂点 \(v'\)について、\(d(v',Y)=d(v',A)+d(A,B)+d(B,Y)=d(v',A)+1+\frac{D-1}{2}\leq D\) が成り立つことから、\(d(A,v')\leq \frac{D-1}{2}\) が成り立ちます。 これより、 \(T'\) に属する \(2\) 頂点 \(u'\) , \(v'\) について、 \(d(u',v')\leq d(A,u')+d(A,v')\leq D-1\) が成り立ちます。同様に\(T"\) に属する \(2\) 頂点 \(u"\) , \(v"\) について、 \(d(u",v")\leq d(B,u")+d(B,v")\leq D-1\) が成り立ちます。
これより、条件をみたすためには \(T'\) , \(T"\) からちょうど一つずつ頂点を選んで赤く塗る必要がある事が分かります。それぞれ \(T'\), \(T"\) に属する頂点 \(v'\) , \(v"\) の間の距離は \(d(v',v")=d(A,v')+d(A,B)+d(B,v")\) と表され、\(d(A,v')\leq \frac{D-1}{2}\), \(d(B,v")\leq \frac{D-1}{2}\) であることから、\(d(v',v")=D\) であることは \(d(A,v')= \frac{D-1}{2}\) かつ \(d(B,v")= \frac{D-1}{2}\) であることと同値である事が分かります。
よって、、\(T'\) に属する点であって、\(A\) から距離 \(\frac{D-1}{2}\) である点の数を \(M_1\) , \(T"\) に属する点であって、\(B\) から距離 \(\frac{D-1}{2}\) である点の数を \(M_2\) とすると、 答えは \(M_1M_2\) 通りとなります。

  • \(D\) が偶数のとき

奇数の時と似た議論で解くことができます。 今度は \(C=V_{\frac{D}{2}}\)を取り、\(C\) と直接結ばれている点を \(A_1\), \(\ldots\) , \(A_K\) とします。ここで、\(C\) および \(C\) とつながっている辺をすべて取り除き、\(T\)\(K\) 個の部分木 \(T_1\) , \(\dots\) , \(T_K\) に分けることを考えます。ただし、\(A_i\) が含まれる部分木を \(T_i\) とおくことにします。
また、\(X\) , \(Y\) が異なる部分木に属する事から、任意の \(i\) について \(X\) , \(Y\) の少なくともどちらかは \(T_i\) に属しておらず、 \(T_i\) に属する頂点 \(v\) について、 \(d(X,v)\) , \(d(Y,v)\) のどちらかは \(\frac{D}{2}+d(C,A_i)+d(A_i,v)\) と書くことができ、これが \(D\) 以下となる事から \(d(A_i,v)\leq \frac{D}{2}-1\) と結論付けられます.
これより, それぞれの \(T_i\) に属する \(2\) 頂点の間の距離は\(D-2\) 以下であり、条件をみたすためには、それぞれの部分木から高々 \(1\) 頂点しか赤く塗れない事が分かります。
相異なる部分木 \(T_i\) , \(T_j\) に属する \(2\) 頂点 \(u\) , \(v\) の間の距離は \(d(A_i,u)+d(A_i,A_j)+d(A_j,v)=d(A_i,u)+d(A_j,v)+2\) と書かれ、\(d(A_i,u)\leq \frac{D}{2}-1\) , \(d(A_j,v)\leq \frac{D}{2}-1\) より、\(d(u,v)=D\) であることは、 \(d(A_i,u)= \frac{D}{2}-1\) かつ \(d(A_j,v)= \frac{D}{2}-1\) と同値である事が分かります。
よって、 \(T_i\) に属し、 \(A_i\) からの距離が \(\frac{D}{2}-1\) である頂点の数を \(M_i\)とすると、それぞれの木で条件をみたす頂点のうちから高々 \(1\) 点を選ぶ方法が \((M_1+1)\cdots (M_K+1)\) 通りあって、そこから \(1\) 頂点も選ばない \(1\) 通りと \(1\) 頂点しか選ばない \((M_1+M_2\cdots +M_K)\) 通りを差し引いた \((M_1+1)\cdots (M_K+1)-(M_1+M_2\cdots +M_K)-1\) 通りが答えとなります。

直径を求めるパート、\(A\),\(B\) または \(C\) を求めるパート、 条件をみたす頂点の数を数えるパート (=\(M_i\) を求めるパート ) および最後に答えを計算するパートのいずれも \(O(N)\) で行う事が出来るため、よって、この問題を解く事ができました。

#include <bits/stdc++.h>
using namespace std;

#define N 200010
#define MOD (ll)998244353
#define ll long long
#define pb push_back
#define rep(i, n) for(ll i = 0; i < n; ++i)

vector<int>e[N];

int d[N][5];
int mx[5];
int mv[5];
int idx, blc, num, cnt;

void dfs(int k, int p) {
	if (p == -1)d[k][idx] = 0;
	else d[k][idx] = d[p][idx] + 1;
	if (mx[idx] < d[k][idx]) {
		mx[idx] = d[k][idx];
		mv[idx] = k;
	}
	if (d[k][idx] == num)cnt++;
	int sz = e[k].size();
	rep(i, sz) {
		if ((e[k][i] != blc) && (d[e[k][i]][idx] < 0))dfs(e[k][i], k);
	}
	return;
}

int main(void) {
	int n;
	int u, v, sz;
	vector<int>a;
	ll x, ans;

	rep(i, 5) {
		rep(j, N)d[j][i] = -1;
		mx[i] = -1;
	}
	blc = -1;

	cin >> n;
	rep(i, n - 1) {
		cin >> u >> v;
		e[u - 1].pb(v - 1);
		e[v - 1].pb(u - 1);
	}

	idx = 0;
	dfs(0, -1);
	idx = 1;
	dfs(mv[0], -1);
	idx = 2;
	dfs(mv[1], -1);

	if (mx[1] % 2 == 1) {
		rep(i, n) {
			if ((d[i][1] == (mx[1] / 2)) && (d[i][2] == (mx[1] / 2) + 1))u = i;
			if ((d[i][2] == (mx[1] / 2)) && (d[i][1] == (mx[1] / 2) + 1))v = i;
		}

		idx = 3;
		num = mx[1] / 2;

		blc = v;
		cnt = 0;
		dfs(u, -1);
		a.pb(cnt);

		blc = u;
		cnt = 0;
		dfs(v, -1);
		a.pb(cnt);
	}
	else {
		rep(i, n) {
			if ((d[i][1] == (mx[1] / 2)) && (d[i][2] == (mx[1] / 2)))u = i;
		}

		idx = 3;
		num = (mx[1] / 2) - 1;
		blc = u;

		sz = e[u].size();
		rep(i, sz) {
			cnt = 0;
			dfs(e[u][i], -1);
			a.pb(cnt);
		}
	}

	sz = a.size();
	ans = 1;
	rep(i, sz) {
		x = a[i] + 1;
		ans = (ans*x) % MOD;
	}
	rep(i, sz) {
		x = a[i];
		ans = (ans + MOD - x) % MOD;
	}
	ans = (ans + MOD - 1) % MOD;

	cout << ans << endl;
	return 0;
}

posted:
last update: