G - Many MST Editorial by en_translator
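First, subtract \(1\) from every edge weight, so that the weights lie between \(0\) (inclusive) and \(M\) (exclusive). A spanning tree has \(N-1\) edges, so this lowers the weight of every minimum spanning tree by exactly \(N-1\). After solving the shifted problem, we add \((N-1)\times M^{N(N-1)/2}\) to obtain the answer.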
For a connected graph \(G\) with edge weights between \(0\) (inclusive) and \(M\) (exclusive), the total weight of the edges in its minimum spanning tree equals \(\displaystyle \sum_{k=1}^{M} c(G_k)-M\), where \(G_k\) is the (unweighted) graph consisting of the edges of \(G\) with weight less than \(k\), and \(c(G_k)\) is the number of connected components of \(G_k\). Indeed, the minimum spanning tree contains exactly \(c(G_k)-1\) edges of weight at least \(k\), and an edge of weight \(w\) is counted once for each \(k=1,2,\ldots,w\).
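The identity can be checked numerically on a small graph. The following is only an illustrative sketch, not part of the intended solution; the names `Edge`, `Dsu`, `mstByKruskal`, and `mstByComponents` are hypothetical, and the graph is assumed to be connected with integer weights in \([0, M)\).

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical helper types, introduced for illustration only.
struct Edge { int u, v, w; };

struct Dsu {
    std::vector<int> parent;
    explicit Dsu(int n) : parent(n) { std::iota(parent.begin(), parent.end(), 0); }
    int find(int x) {
        while (parent[x] != x) x = parent[x] = parent[parent[x]];
        return x;
    }
    // Returns true if x and y were in different components before the call.
    bool unite(int x, int y) {
        x = find(x);
        y = find(y);
        if (x == y) return false;
        parent[x] = y;
        return true;
    }
};

// MST weight by Kruskal's algorithm (the graph is assumed connected).
long long mstByKruskal(int n, std::vector<Edge> edges) {
    std::sort(edges.begin(), edges.end(),
              [](const Edge& a, const Edge& b) { return a.w < b.w; });
    Dsu dsu(n);
    long long total = 0;
    for (const Edge& e : edges)
        if (dsu.unite(e.u, e.v)) total += e.w;
    return total;
}

// MST weight via sum_{k=1}^{M} c(G_k) - M, where G_k keeps the edges of weight < k.
long long mstByComponents(int n, const std::vector<Edge>& edges, int M) {
    long long total = -M;
    for (int k = 1; k <= M; k++) {
        Dsu dsu(n);
        int components = n;
        for (const Edge& e : edges) {
            assert(0 <= e.w && e.w < M);  // weights must lie in [0, M)
            if (e.w < k && dsu.unite(e.u, e.v)) components--;
        }
        total += components;
    }
    return total;
}
```

Both functions should return the same value on any connected input; the second one never sorts or picks tree edges, which is what makes the later counting argument possible.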
Let \(S_N\) be the set of all \(N\)-vertex complete graphs with weights between \(0\) (inclusive) and \(M\) (exclusive), and let \(C(G_k)\) be the set of connected components of \(G_k\). Since \(|S_N|=M^{N(N-1)/2}\), the answer to the shifted problem is
\[\displaystyle \sum_{G\in S_N}\left(\sum_{k=1}^{M}c(G_k)-M\right)=-M\times M^{N(N-1)/2}+\sum_{k=1}^{M}\sum_{G\in S_N}c(G_k)=-M\times M^{N(N-1)/2}+\sum_{k=1}^{M}\sum_{G\in S_N}\sum_{H\in C(G_k)}1.\]
(We used \(c(G_k)=|C(G_k)|\).) Below, we derive how to compute \(\displaystyle \sum_{G\in S_N}\sum_{H\in C(G_k)}1\) for each \(k\).
(Intuition so far: by decomposing the contributions to the answer according to edge weights, the problem reduces to a counting problem on unweighted connected graphs, which is easier to tackle.)
For a fixed \(H\) (which is a connected subgraph of the \(N\)-vertex complete graph), the count of graphs \(G\in S_N\) with \(H\in C(G_k)\) is
\[(M-k)^{|V(H)|(|V(H)|-1)/2-m(H)}\times k^{m(H)}\times (M-k)^{|V(H)|(N-|V(H)|)}\times M^{(N-|V(H)|)(N-|V(H)|-1)/2},\]
where \(V(H)\) is the vertex set of \(H\) and \(m(H)\) is the number of edges in \(H\). This is derived by considering the possible weights of each edge of \(G\). If both endpoints are in \(V(H)\), the weight is less than \(k\) if the edge is contained in \(H\), and at least \(k\) otherwise. If the edge lies between \(V(H)\) and \(\{1,2,\ldots,N\}\setminus V(H)\), its weight is at least \(k\). If both endpoints are in \(\{1,2,\ldots,N\}\setminus V(H)\), the weight can be anything.
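As a small worked instance (an illustration only), take \(N=3\) and let \(H\) be the single edge between vertices \(1\) and \(2\), so that \(|V(H)|=2\) and \(m(H)=1\). The formula gives
\[(M-k)^{2\cdot 1/2-1}\times k^{1}\times (M-k)^{2\cdot 1}\times M^{1\cdot 0/2}=k(M-k)^2,\]
which matches a direct count: the edge \(\{1,2\}\) must have weight less than \(k\) (\(k\) choices), while the edges \(\{1,3\}\) and \(\{2,3\}\) must each have weight at least \(k\) (\(M-k\) choices each).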
By the “contribution to the sum” trick (swapping the order of summation so that we sum over \(H\) first):
\[\displaystyle \sum_{G\in S_N}\sum_{H\in C(G_k)}1=\sum_{H}(M-k)^{|V(H)|(|V(H)|-1)/2-m(H)}\times k^{m(H)}\times (M-k)^{|V(H)|(N-|V(H)|)}\times M^{(N-|V(H)|)(N-|V(H)|-1)/2}.\]
The summand depends on \(H\) only through \(|V(H)|\) and \(m(H)\), so every choice of an \(s\)-element vertex set contributes equally: the sum over connected graphs \(H\) with \(s\) vertices is \(\dbinom{N}{s}\) times the sum over connected graphs \(H\) with vertex set \(\{1,2,\ldots,s\}\). Denoting by \(f(s)\) the sum of \((M-k)^{|V(H)|(|V(H)|-1)/2-m(H)}\times k^{m(H)}\) over all connected graphs \(H\) with vertex set \(\{1,2,\ldots,s\}\), we arrive at
\[\displaystyle \sum_{H}(M-k)^{|V(H)|(|V(H)|-1)/2-m(H)}\times k^{m(H)}\times (M-k)^{|V(H)|(N-|V(H)|)}\times M^{(N-|V(H)|)(N-|V(H)|-1)/2}=\sum_{s=1}^N \binom{N}{s}f(s)\times (M-k)^{s(N-s)}\times M^{(N-s)(N-s-1)/2}.\]
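Thus, it suffices to find \(f(s)\) for each \(s=1,2,\ldots,N\). This can be done in the same manner as ABC213G. Summing \((M-k)^{s(s-1)/2-m}\times k^{m}\) over all (not necessarily connected) graphs on \(\{1,2,\ldots,s\}\) gives \((k+(M-k))^{s(s-1)/2}=M^{s(s-1)/2}\); from this we subtract the cases where the connected component containing vertex \(1\) has exactly \(i<s\) vertices. This yields \(f(s)=M^{s(s-1)/2}-\displaystyle\sum_{i=1}^{s-1}f(i)\dbinom{s-1}{i-1}(M-k)^{i(s-i)}M^{(s-i)(s-i-1)/2}\). With the binomial coefficients and powers precomputed, this equation lets us find \(f(1),f(2),\ldots,f(N)\) in a total of \(O(N^2)\) time.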
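The recurrence for \(f(s)\) can be sketched in isolation as follows. This is a minimal illustration under stated assumptions, not the reference implementation: it uses plain `long long` arithmetic modulo \(998244353\) instead of a modint library, the names `powMod` and `connectedWeightSums` are hypothetical, and it calls `powMod` inside the loops, so it runs in \(O(N^2\log N)\) per \(k\) rather than the \(O(N^2)\) obtained with precomputed power tables.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// Modular exponentiation: base^exp mod MOD, with the convention 0^0 = 1.
long long powMod(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Returns f[0..n] for fixed M and k (f[0] is unused and left as 0).
std::vector<long long> connectedWeightSums(int n, long long M, long long k) {
    assert(1 <= k && k <= M);
    // Pascal's triangle, used for binom(s-1, i-1).
    std::vector<std::vector<long long>> binom(n + 1, std::vector<long long>(n + 1, 0));
    for (int i = 0; i <= n; i++) {
        binom[i][0] = 1;
        for (int j = 1; j <= i; j++)
            binom[i][j] = (binom[i - 1][j - 1] + binom[i - 1][j]) % MOD;
    }
    std::vector<long long> f(n + 1, 0);
    for (int s = 1; s <= n; s++) {
        // Start from all graphs on s vertices: M^{s(s-1)/2}.
        long long value = powMod(M, (long long)s * (s - 1) / 2);
        // Subtract the cases where vertex 1's component has i < s vertices.
        for (int i = 1; i < s; i++) {
            long long term = f[i] * binom[s - 1][i - 1] % MOD;
            term = term * powMod(M - k, (long long)i * (s - i)) % MOD;
            term = term * powMod(M, (long long)(s - i) * (s - i - 1) / 2) % MOD;
            value = (value - term + MOD) % MOD;
        }
        f[s] = value;
    }
    return f;
}
```

For example, with \(M=3\) and \(k=1\) the three-vertex value is \(f(3)=7\): each of the three two-edge paths contributes \(k^2(M-k)=2\), and the triangle contributes \(k^3=1\).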
By performing this for each \(k=1,2,\ldots,M\), the answer can be found in a total of \(O(N^2M)\) time.
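As a sanity check of the whole derivation (a worked example only), take \(N=2\). The recurrence gives \(f(1)=1\) and \(f(2)=M-(M-k)=k\), so for each \(k\) the inner sum is \(\dbinom{2}{1}\times 1\times (M-k)+\dbinom{2}{2}\times k=2M-k\). Hence the shifted answer is
\[\displaystyle -M\times M+\sum_{k=1}^{M}(2M-k)=M^2-\frac{M(M+1)}{2}=\frac{M(M-1)}{2},\]
and adding \((N-1)\times M^{N(N-1)/2}=M\) gives \(\dfrac{M(M+1)}{2}=1+2+\cdots+M\), which is indeed the sum of the single edge's weight over all \(M\) graphs.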
#include <bits/stdc++.h>
using namespace std;
#include <atcoder/modint>
using namespace atcoder;
using mint = modint998244353;

mint binom[510][510];   // binom[i][j] = C(i, j)
mint POW[510][150000];  // POW[i][j] = i^j; exponents go up to N(N-1)/2

int main() {
    // Precompute binomial coefficients and powers.
    for (int i = 0; i < 510; i++) {
        binom[i][0] = 1;
        binom[i][i] = 1;
        for (int j = 1; j < i; j++) binom[i][j] = binom[i - 1][j - 1] + binom[i - 1][j];
        POW[i][0] = 1;
        for (int j = 1; j < 150000; j++) POW[i][j] = POW[i][j - 1] * i;
    }
    int n, m;
    cin >> n >> m;
    // (N-1) * M^{N(N-1)/2} from the weight shift, minus M * M^{N(N-1)/2}.
    mint ans = (n - 1 - m) * POW[m][n * (n - 1) / 2];
    for (int k = 1; k <= m; k++) {
        vector<mint> f(n + 1, 0);
        for (int s = 1; s <= n; s++) {
            // f(s): all graphs on s vertices, minus those where the
            // component containing vertex 1 has i < s vertices.
            f[s] = POW[m][s * (s - 1) / 2];
            for (int i = 1; i < s; i++) f[s] -= f[i] * binom[s - 1][i - 1] * POW[m - k][i * (s - i)] * POW[m][(s - i) * (s - i - 1) / 2];
            // Contribution of connected components with exactly s vertices.
            ans += binom[n][s] * f[s] * POW[m - k][s * (n - s)] * POW[m][(n - s) * (n - s - 1) / 2];
        }
    }
    cout << ans.val() << endl;
}
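For testing, a brute-force reference can be useful. The sketch below is an assumption-laden illustration (the name `bruteForceManyMst` is hypothetical): it enumerates every assignment of weights \(1,2,\ldots,M\) to the edges of the \(N\)-vertex complete graph and sums the minimum spanning tree weights directly. It is exponential in the number of edges, so it is only usable for tiny \(N\) and \(M\), and it returns the exact sum without reducing modulo \(998244353\).

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Sum of MST weights over all M^{N(N-1)/2} weighted complete graphs,
// with weights in 1..M. Exponential time: for tiny inputs only.
long long bruteForceManyMst(int n, int m) {
    assert(n >= 2 && m >= 1);
    std::vector<std::pair<int, int>> edges;
    for (int u = 0; u < n; u++)
        for (int v = u + 1; v < n; v++) edges.push_back({u, v});
    const int e = (int)edges.size();

    std::vector<int> w(e, 1);  // current weight assignment
    long long total = 0;
    while (true) {
        // Kruskal's algorithm, scanning the weights in increasing order.
        std::vector<int> parent(n);
        for (int i = 0; i < n; i++) parent[i] = i;
        auto find = [&](int x) {
            while (parent[x] != x) x = parent[x] = parent[parent[x]];
            return x;
        };
        for (int weight = 1; weight <= m; weight++) {
            for (int j = 0; j < e; j++) {
                if (w[j] != weight) continue;
                int a = find(edges[j].first), b = find(edges[j].second);
                if (a != b) {
                    parent[a] = b;
                    total += weight;
                }
            }
        }
        // Advance to the next assignment, like an odometer in base m.
        int pos = 0;
        while (pos < e && w[pos] == m) w[pos++] = 1;
        if (pos == e) break;
        w[pos]++;
    }
    return total;
}
```

Comparing its output against the fast solution on a handful of small inputs is a cheap way to catch exponent or indexing mistakes in the formulas.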