G - Electric Circuit Editorial by en_translator (Official)


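By linearity of expectation, the expected number of connected components equals the sum, over all nonempty \(S \subseteq \lbrace 1,2,\dots,N \rbrace\), of the probability that \(S\) forms a connected component (that is, \(S\) is exactly the vertex set of one of the connected components). Thus it suffices to find each of these probabilities. Here, let us define \(f(S)\) and \(g(S)\) as follows: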

  • \(f(S) = g(S) = 0\) if \(S\) contains different numbers of red and blue endpoints
  • Otherwise, let \(m\) be the number of red endpoints in \(S\) (which equals the number of blue endpoints in \(S\)); then
    • \(f(S)=\) (the number of ways to connect the cables, considering only the endpoints in \(S\)) \(= m!\)
    • \(g(S)=\) (the number of ways to connect the cables, considering only the endpoints in \(S\), so that the whole of \(S\) becomes connected).
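To make the definitions concrete, here is a brute-force sketch that computes \(g(S)\) for tiny inputs by trying all \(m!\) ways to connect the endpoints in \(S\) (so the loop visits exactly \(f(S)\) connections). The encoding is an assumption for illustration: `red[i]` and `blue[i]` are the 0-indexed parts holding the \(i\)-th red and blue endpoint, and \(S\) is a bitmask over the \(N\) parts; the names `gBrute` and `DSU` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Minimal union-find over the parts (hypothetical helper).
struct DSU {
    std::vector<int> parent;
    explicit DSU(int n) : parent(n) { std::iota(parent.begin(), parent.end(), 0); }
    int find(int x) { return parent[x] == x ? x : parent[x] = find(parent[x]); }
    void unite(int a, int b) { parent[find(a)] = find(b); }
};

// Brute-force g(S): number of ways to connect the red endpoints in S to the
// blue endpoints in S so that all parts of S end up connected.
// Assumed encoding: red[i] / blue[i] = 0-indexed part of the i-th red / blue
// endpoint, S = nonempty bitmask of parts. Runs in O(m! * n): tiny inputs only.
long long gBrute(int n, const std::vector<int>& red, const std::vector<int>& blue, int S) {
    assert(S > 0 && S < (1 << n));
    std::vector<int> r, b;  // parts of the endpoints lying in S
    for (int x : red) if (S >> x & 1) r.push_back(x);
    for (int x : blue) if (S >> x & 1) b.push_back(x);
    if (r.size() != b.size()) return 0;  // different counts: f(S) = g(S) = 0
    const int m = static_cast<int>(r.size());
    std::vector<int> perm(m);  // red endpoint i is connected to blue endpoint perm[i]
    std::iota(perm.begin(), perm.end(), 0);
    long long connected = 0;
    do {  // visits all m! = f(S) connections
        DSU d(n);
        for (int i = 0; i < m; i++) d.unite(r[i], b[perm[i]]);
        int root = -1;
        bool ok = true;
        for (int v = 0; v < n; v++) {
            if (!(S >> v & 1)) continue;
            if (root == -1) root = d.find(v);
            else if (d.find(v) != root) ok = false;
        }
        if (ok) connected++;
    } while (std::next_permutation(perm.begin(), perm.end()));
    return connected;
}
```

For example, with red endpoints on parts 0 and 1 and blue endpoints on parts 1 and 0, `gBrute(2, {0, 1}, {1, 0}, 3)` returns 1: of the \(2! = 2\) connections, only one joins the two parts.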

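Once we obtain \(g(S)\), the probability that \(S\) forms a connected component is easy to find. Let \(M\) be the total number of cables. The endpoints in \(S\) must be connected among themselves so that \(S\) becomes connected, which can be done in \(g(S)\) ways, and the remaining \(M-m\) red endpoints outside \(S\) may be connected to the remaining \(M-m\) blue endpoints arbitrarily, in \((M-m)!\) ways. Dividing by the \(M!\) possible connections, the probability is \(g(S)\,(M-m)!/M!\). However, it is hard to find \(g(S)\) directly; instead, we exploit a relation between \(f(S)\) and \(g(S)\) to compute \(g(S)\) from \(f(S)\). Classifying the \(m!\) ways to connect the cables within \(S\) by the partition of \(S\) into connected components that they produce, we obtain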

\[\displaystyle f(S) = \sum_{\lbrace s_1, s_2,\dots, s_k \rbrace} \prod_{i=1}^{k} g(s_i),\]

where the sum ranges over all partitions \(\lbrace s_1, s_2,\dots, s_k \rbrace\) of \(S\) into nonempty sets. The term with \(k=1\), where \(S\) itself is the only block, equals \(g(S)\); isolating this term, we have

\[\displaystyle g(S) =f(S) - \sum_{\substack{\lbrace s_1, s_2,\dots, s_k \rbrace\\k>1}} \prod_{i=1}^{k} g(s_i).\]
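As a sanity check of the partition identity for \(f(S)\), the following sketch evaluates the sum over partitions literally, listing every partition of \(S\) as a restricted growth string (each element receives a block label at most one larger than the largest label used so far). The table `g` passed in is an arbitrary, hypothetical input indexed by bitmask; with all entries equal to 1 the function simply counts set partitions, so it returns Bell numbers.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Evaluates sum over partitions {s_1, ..., s_k} of S of prod_i g(s_i) by listing
// every set partition of the elements of S explicitly (exponential; tiny S only).
// g is an arbitrary table indexed by bitmask (hypothetical input).
long long sumOverPartitions(int S, const std::vector<long long>& g) {
    assert(S >= 0 && S < static_cast<int>(g.size()));
    std::vector<int> elems;  // elements (bit positions) of S
    for (int v = 0; (S >> v) != 0; v++)
        if (S >> v & 1) elems.push_back(v);
    const int k = static_cast<int>(elems.size());
    std::vector<int> label(k);  // label[j] = block of the j-th element
    long long total = 0;
    // Restricted growth strings: element j gets a label in [0, maxLabel + 1],
    // so every partition is generated exactly once.
    auto rec = [&](auto&& self, int j, int maxLabel) -> void {
        if (j == k) {
            std::vector<int> block(maxLabel + 1, 0);  // bitmask of each block
            for (int t = 0; t < k; t++) block[label[t]] |= 1 << elems[t];
            long long prod = 1;
            for (int s : block) prod *= g[s];
            total += prod;
            return;
        }
        for (int c = 0; c <= maxLabel + 1; c++) {
            label[j] = c;
            self(self, j + 1, std::max(maxLabel, c));
        }
    };
    rec(rec, 0, -1);
    return total;
}
```

With `g` equal to 1 on all 16 masks, `sumOverPartitions(7, g)` is 5 and `sumOverPartitions(15, g)` is 15; with \(g(\lbrace 1\rbrace)=2\), \(g(\lbrace 2\rbrace)=3\), \(g(\lbrace 1,2\rbrace)=5\), the two partitions of \(\lbrace 1,2\rbrace\) contribute \(5 + 2\cdot 3 = 11\).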

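Fix an arbitrary element \(a \in S\) and classify the partitions by the block \(T\) that contains \(a\). The remaining blocks form an arbitrary partition of \(S \setminus T\), and by the identity above the sum of their products of \(g\) is exactly \(f(S \setminus T)\). Since \(k>1\) is equivalent to \(T \neq S\), we can further rewrite the equation as

\[\displaystyle g(S) = f(S) - \sum_{\substack{T \subsetneq S\\a \in T}} g(T)\, f(S \setminus T).\]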
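Below is a minimal sketch of this recurrence, written without the AtCoder Library so that it is self-contained, with reductions modulo 998244353 done by hand. It assumes the endpoint counts are given as tables `rc[S]` and `bc[S]` (hypothetical names) indexed by bitmask, and it fixes \(a\) as the lowest-indexed part of \(S\); the sample code below fixes the highest one instead, and either choice is valid.

```cpp
#include <cassert>
#include <vector>

constexpr long long MOD = 998244353;

// Sketch of g(S) = f(S) - sum_{a in T, T proper subset of S} g(T) f(S \ T), mod 998244353.
// Assumed input (hypothetical): rc[S] / bc[S] = numbers of red / blue endpoints
// on the parts in bitmask S, for all 2^n masks. Fixes a = lowest part of S.
std::vector<long long> computeG(int n, const std::vector<int>& rc, const std::vector<int>& bc) {
    const int full = (1 << n) - 1;
    assert(static_cast<int>(rc.size()) == full + 1 && rc.size() == bc.size());
    std::vector<long long> fact(rc[full] + 1, 1);  // rc[full] = M
    for (int i = 1; i <= rc[full]; i++) fact[i] = fact[i - 1] * i % MOD;
    std::vector<long long> f(full + 1, 0), g(full + 1, 0);
    for (int S = 1; S <= full; S++) {  // T and S \ T are numerically smaller than S
        if (rc[S] != bc[S]) continue;  // f(S) = g(S) = 0
        f[S] = g[S] = fact[rc[S]];
        const int a = S & -S;  // lowest element of S, as a one-bit mask
        const int rest = S ^ a;
        if (rest == 0) continue;  // singleton: the only T containing a is S itself
        // U runs over the proper subsets of rest (including 0), so T = a | U
        // runs over the proper subsets of S that contain a.
        for (int U = (rest - 1) & rest;; U = (U - 1) & rest) {
            const int T = a | U;
            g[S] = (g[S] + MOD - g[T] * f[S ^ T] % MOD) % MOD;
            if (U == 0) break;
        }
    }
    return g;
}
```

For two parts with red endpoints on parts 1 and 2 and blue endpoints on parts 2 and 1, `computeG(2, {0, 1, 1, 2}, {0, 1, 1, 2})` gives `g[3] == 1`, matching \(g(\lbrace 1,2\rbrace) = 2! - g(\lbrace 1\rbrace) f(\lbrace 2\rbrace) = 1\).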

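With the equation above, \(g(S)\) can be computed for every nonempty \(S\) in increasing order of the bitmask, since the right-hand side involves \(g\) only on proper subsets of \(S\), which have smaller bitmasks. Enumerating the subsets of every \(S\) takes \(O(3^N)\) time in total, so this bit DP (Dynamic Programming) solves the problem. It can be made even faster by computing \(g\) as the logarithm of \(f\) with respect to subset convolution (by the partition identity, \(f\) is the exponential of \(g\) in this sense), which takes \(O(N^2 2^N)\) time.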
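The \(O(3^N)\) bound comes from the standard submask walk `T = (T - 1) & S`: over all \(S\), the pairs \(T \subseteq S\) number \(3^N\), since each part lies in \(T\), in \(S \setminus T\), or outside \(S\). The sketch below counts those pairs, and also checks that the loop condition `sub > (bit - sub)` in the sample code visits exactly the proper subsets of `bit` that contain its highest set bit (so the sample code takes \(a\) to be the highest part). All helper names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Counts pairs (S, T) with T a subset of S over n elements, using the usual
// submask walk T = (T - 1) & S. The result is 3^n.
long long countSubmaskPairs(int n) {
    long long pairs = 0;
    for (int S = 0; S < (1 << n); S++) {
        for (int T = S;; T = (T - 1) & S) {
            pairs++;
            if (T == 0) break;
        }
    }
    return pairs;
}

// The loop of the sample code: proper submasks of bit in decreasing order,
// stopping as soon as sub no longer contains the highest set bit of bit.
std::vector<int> sampleLoopSubs(int bit) {
    std::vector<int> subs;
    for (int sub = (bit - 1) & bit; sub > (bit - sub); sub = (sub - 1) & bit) subs.push_back(sub);
    return subs;
}

// Direct definition for comparison: proper submasks of bit that contain its
// highest set bit, in decreasing order.
std::vector<int> subsWithTopBit(int bit) {
    assert(bit > 0);
    int top = 1;
    while (top <= bit / 2) top *= 2;  // highest set bit of bit
    std::vector<int> subs;
    for (int sub = bit - 1; sub > 0; sub--)
        if ((sub & bit) == sub && (sub & top)) subs.push_back(sub);
    return subs;
}
```

The condition works because `sub` and `bit - sub` are disjoint, so whichever of them holds the highest bit of `bit` is the larger number; for instance, `countSubmaskPairs(5)` is \(3^5 = 243\).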

Sample code (C++):

#include<bits/stdc++.h>
#include<atcoder/modint>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

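int main() {
    int n, m;
    cin >> n >> m;
    // r_cnt[S] / b_cnt[S]: number of red / blue endpoints on the parts in bitmask S
    vector<int> r_cnt(1 << n), b_cnt(1 << n);
    for (int t = 0; t < 2; t++) {
        // t = 0 reads the red endpoints, t = 1 the blue ones
        vector<int> v(n);  // v[j]: number of endpoints of the current color on part j
        for (int i = 0; i < m; i++) {
            int r;
            cin >> r;
            v[--r]++;
        }
        for (int i = 0; i < (1 << n); i++) {
            for (int j = 0; j < n; j++) {
                if (i >> j & 1) r_cnt[i] += v[j];
            }
        }
        // after both swaps, r_cnt holds the red counts and b_cnt the blue counts
        swap(r_cnt, b_cnt);
    }
    vector<mint> fact(m + 1, 1);
    for (int i = 1; i <= m; i++) fact[i] = fact[i - 1] * i;
    vector<mint> f(1 << n), g(1 << n);
    mint ans;
    for (int bit = 1; bit < (1 << n); bit++) {
        if (r_cnt[bit] != b_cnt[bit]) continue;  // f = g = 0
        f[bit] = g[bit] = fact[r_cnt[bit]];
        // sub runs over the proper subsets of bit containing its highest element
        // (a = highest part); sub > bit - sub holds exactly for those subsets
        for (int sub = (bit - 1) & bit; sub > (bit - sub); sub = (sub - 1) & bit) {
            g[bit] -= g[sub] * f[bit - sub];
        }
        // g(S) * (M - m)! counts the connections in which S is a connected component
        ans += g[bit] * fact[m - r_cnt[bit]];
    }
    // divide by M! to turn the counts into probabilities
    for (int i = 1; i <= m; i++) ans /= i;
    cout << ans.val() << endl;
}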
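For stress-testing the sample code on small inputs, a brute-force reference can enumerate all \(M!\) connections with `std::next_permutation`, count connected components with a union-find, and divide by \(M!\) modulo 998244353. This is a sketch for tiny \(M\) only; `bruteExpected` and `modPow` are hypothetical names, and `R`, `B` are assumed to hold the 1-indexed parts of the red and blue endpoints in the order the sample code reads them.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

constexpr long long MOD = 998244353;

// Fast exponentiation modulo MOD (used for the modular inverse of M!).
long long modPow(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Brute-force expected number of connected components, mod 998244353.
// R[i] / B[i]: 1-indexed parts of the i-th red / blue endpoint.
// Enumerates all M! connections: tiny M only.
long long bruteExpected(int n, const std::vector<int>& R, const std::vector<int>& B) {
    assert(R.size() == B.size());
    const int M = static_cast<int>(R.size());
    std::vector<int> perm(M);  // red endpoint i is connected to blue endpoint perm[i]
    std::iota(perm.begin(), perm.end(), 0);
    long long components = 0, connections = 0;
    do {
        std::vector<int> parent(n);
        std::iota(parent.begin(), parent.end(), 0);
        auto find = [&](int x) {
            while (parent[x] != x) {
                parent[x] = parent[parent[x]];  // path halving
                x = parent[x];
            }
            return x;
        };
        for (int i = 0; i < M; i++) {
            const int a = find(R[i] - 1);
            const int b = find(B[perm[i]] - 1);
            parent[a] = b;
        }
        for (int v = 0; v < n; v++)
            if (find(v) == v) components++;  // one root per component
        connections++;
    } while (std::next_permutation(perm.begin(), perm.end()));
    // Average over the connections == M! equally likely outcomes
    return components % MOD * modPow(connections, MOD - 2) % MOD;
}
```

For instance, `bruteExpected(2, {1, 2}, {2, 1})` averages 1 and 2 components over the two connections and returns \(3/2 \bmod 998244353 = 499122178\); comparing such values against the sample code on random small inputs is a quick consistency check.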
