G - Electric Circuit Editorial
by
yuto1115
解説
期待値の線形性より、各 \(S \subseteq \lbrace 1,2,\dots,N \rbrace\ (S \neq \emptyset)\) に対して、\(S\) が \(1\) つの連結成分となる確率を求めればよいです。ここで、\(f(S)\) および \(g(S)\) を以下のように定義します。
- \(S\) 内に存在する赤い端点と青い端点の数が異なるならば、\(f(S) = g(S) = 0\)
- そうでないならば、\(S\) 内に存在する赤い端点の数を \(m\) として、
- \(f(S)=\)(\(S\) 内に存在する端点だけを考えたときのケーブルの繋ぎ方の総数)\(= m!\)
- \(g(S)=\)(\(S\) 内に存在する端点だけを考えたとき、\(S\) 全体が連結になるようなケーブルの繋ぎ方の数)
\(S\) が \(1\) つの連結成分となる確率は \(g(S)\) から簡単に求まります。しかし、 \(g(S)\) 自体を直接求めることは困難なので、\(f(S)\) と \(g(S)\) の間の何らかの関係式を利用して \(f(S)\) から \(g(S)\) を求めることを考えます。実際、\(m\) 本のケーブルを繋いだとき \(S\) がどのように連結成分に分割されるかを考えることで、
\[\displaystyle f(S) = \sum_{\lbrace s_1, s_2,\dots, s_k \rbrace} \prod_{i=1}^{k} g(s_i)\]
が得られます。ここで、集合族 \(\lbrace s_1, s_2,\dots, s_k \rbrace\) は \(S\) の分割です。右辺の総和部分のうち \(k=1\) であるもの(すなわち、\(S\) 全体が分割を成しているもの)を移項すると、
\[\displaystyle g(S) =f(S) - \sum_{\substack{\lbrace s_1, s_2,\dots, s_k \rbrace\\k>1}} \prod_{i=1}^{k} g(s_i)\]
となります。\(S\) 内の要素を適当に \(1\) つ固定して \(a\) とおくと、上の式は更に変形でき、
\[\displaystyle g(S) =f(S) - \sum_{\substack{T \subsetneq S\\a \in T}} g(T)f(S\setminus T)\]
が得られます。この式を用いると、\(O(3^N)\) の bit DP によって \(g(S)\) が求まり、本問題を解くことができます。なお、subset convolution の log を用いて更に高速化することも可能です。
実装例 (C++) :
#include<bits/stdc++.h>
#include<atcoder/modint>
using namespace std;
using namespace atcoder;
using mint = modint998244353;
int main() {
int n, m;
cin >> n >> m;
vector<int> r_cnt(1 << n), b_cnt(1 << n);
for (int t = 0; t < 2; t++) {
vector<int> v(n);
for (int i = 0; i < m; i++) {
int r;
cin >> r;
v[--r]++;
}
for (int i = 0; i < (1 << n); i++) {
for (int j = 0; j < n; j++) {
if (i >> j & 1) r_cnt[i] += v[j];
}
}
swap(r_cnt, b_cnt);
}
vector<mint> fact(m + 1, 1);
for (int i = 1; i <= m; i++) fact[i] = fact[i - 1] * i;
vector<mint> f(1 << n), g(1 << n);
mint ans;
for (int bit = 1; bit < (1 << n); bit++) {
if (r_cnt[bit] != b_cnt[bit]) continue;
f[bit] = g[bit] = fact[r_cnt[bit]];
for (int sub = (bit - 1) & bit; sub > (bit - sub); sub = (sub - 1) & bit) {
g[bit] -= g[sub] * f[bit - sub];
}
ans += g[bit] * fact[m - r_cnt[bit]];
}
for (int i = 1; i <= m; i++) ans /= i;
cout << ans.val() << endl;
}
posted:
last update: