Official

G - Electric Circuit Editorial by yuto1115

解説

期待値の線形性より、各 \(S \subseteq \lbrace 1,2,\dots,N \rbrace\ (S \neq \emptyset)\) に対して、\(S\)\(1\) つの連結成分となる確率を求めればよいです。ここで、\(f(S)\) および \(g(S)\) を以下のように定義します。

  • \(S\) 内に存在する赤い端点と青い端点の数が異なるならば、\(f(S) = g(S) = 0\)
  • そうでないならば、\(S\) 内に存在する赤い端点の数を \(m\) として、
    • \(f(S)=\)\(S\) 内に存在する端点だけを考えたときのケーブルの繋ぎ方の総数)\(= m!\)
    • \(g(S)=\)\(S\) 内に存在する端点だけを考えたとき、\(S\) 全体が連結になるようなケーブルの繋ぎ方の数)

\(S\)\(1\) つの連結成分となる確率は \(g(S)\) から簡単に求まります。しかし、 \(g(S)\) 自体を直接求めることは困難なので、\(f(S)\)\(g(S)\) の間の何らかの関係式を利用して \(f(S)\) から \(g(S)\) を求めることを考えます。実際、\(m\) 本のケーブルを繋いだとき \(S\) がどのように連結成分に分割されるかを考えることで、

\[\displaystyle f(S) = \sum_{\lbrace s_1, s_2,\dots, s_k \rbrace} \prod_{i=1}^{k} g(s_i)\]

が得られます。ここで、集合族 \(\lbrace s_1, s_2,\dots, s_k \rbrace\)\(S\)分割です。右辺の総和部分のうち \(k=1\) であるもの(すなわち、\(S\) 全体が分割を成しているもの)を移項すると、

\[\displaystyle g(S) =f(S) - \sum_{\substack{\lbrace s_1, s_2,\dots, s_k \rbrace\\k>1}} \prod_{i=1}^{k} g(s_i)\]

となります。\(S\) 内の要素を適当に \(1\) つ固定して \(a\) とおくと、上の式は更に変形でき、

\[\displaystyle g(S) =f(S) - \sum_{\substack{T \subsetneq S\\a \in T}} g(T)f(S\setminus T)\]

が得られます。この式を用いると、\(O(3^N)\) の bit DP によって \(g(S)\) が求まり、本問題を解くことができます。なお、subset convolution の log を用いて更に高速化することも可能です。

実装例 (C++) :

#include<bits/stdc++.h>
#include<atcoder/modint>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

int main() {
    int n, m;
    cin >> n >> m;
    vector<int> r_cnt(1 << n), b_cnt(1 << n);
    for (int t = 0; t < 2; t++) {
        vector<int> v(n);
        for (int i = 0; i < m; i++) {
            int r;
            cin >> r;
            v[--r]++;
        }
        for (int i = 0; i < (1 << n); i++) {
            for (int j = 0; j < n; j++) {
                if (i >> j & 1) r_cnt[i] += v[j];
            }
        }
        swap(r_cnt, b_cnt);
    }
    vector<mint> fact(m + 1, 1);
    for (int i = 1; i <= m; i++) fact[i] = fact[i - 1] * i;
    vector<mint> f(1 << n), g(1 << n);
    mint ans;
    for (int bit = 1; bit < (1 << n); bit++) {
        if (r_cnt[bit] != b_cnt[bit]) continue;
        f[bit] = g[bit] = fact[r_cnt[bit]];
        for (int sub = (bit - 1) & bit; sub > (bit - sub); sub = (sub - 1) & bit) {
            g[bit] -= g[sub] * f[bit - sub];
        }
        ans += g[bit] * fact[m - r_cnt[bit]];
    }
    for (int i = 1; i <= m; i++) ans /= i;
    cout << ans.val() << endl;
}

posted:
last update: