G - Electric Circuit Editorial by MMNMM

より高速な解法

Set Power Series を用いることで、この問題を \(\Theta(M+2^N N^2)\) 時間で解くことができます。

\(R,B:2 ^ {N}\mapsto\mathbb N\) を、部品の集合に含まれる赤い/青い端子の個数とします。

まず、この問題の答えは \(H(S)\coloneqq S\) が \(1\) つの連結成分になるような結び方の個数 として

\[\dfrac1{M!}\sum _ {S\in2 ^ {N}}H(S)\]

と求められます。

集合冪級数 \(F,G\) を次のように定めます。

  • \(F(S)\coloneqq S\) に含まれる部品についている端子のみをすべて使って接続する方法の通り数
  • \(G(S)\coloneqq S\) に含まれる端子のみをすべて使って \(S\) を一つの連結成分にする方法の通り数

すると、\(H(S)\) は \((M-R(S))!G(S)\) として求められます。

いま、\(F,G\) に関して \(F=\exp(G)\) が成り立っているので、\(G=\log(F)\) として求めることができます。

\(F\) は、次の式を用いて具体的に求めることができます。

\[F(S)=\left\lbrace\begin{matrix}R(S)!&\ &(R(S)=B(S))\\0&&(R(S)\neq B(S))\end{matrix}\right.\]

あとは、\(\log(F)\) を求め、定数倍して総和を取ればよいです。

#include <iostream>
#include <vector>
#include <array>
#include <bit>
#include <ranges>
#include <atcoder/modint.hpp>

// 階乗とその逆元を前計算する
template<unsigned N, unsigned P>
class fact_inv {
public:
    unsigned fact[N];
    unsigned ifact[N];

    constexpr fact_inv() : fact{}, ifact{} {
        fact[0] = 1;
        for (unsigned i{1}; i < N; ++i)
            fact[i] = static_cast<unsigned long>(fact[i - 1]) * i % P;
        ifact[N - 1] = atcoder::internal::inv_gcd(fact[N - 1], P).second;
        for (unsigned i{N}; --i;)
            ifact[i - 1] = static_cast<unsigned long>(ifact[i]) * i % P;
    }
};
constexpr fact_inv<100001, 998244353> precalc{};

int main() {
    using namespace std;
    using modint = atcoder::static_modint<998244353>;

    constexpr unsigned max_N{17};

    unsigned N, M;
    cin >> N >> M;

    // count_R[i] := i 番目の部品についている赤い端子の個数
    // count_B[i] := i 番目の部品についている青い端子の個数
    vector<unsigned> count_R(N), count_B(N);
    for (unsigned x; [[maybe_unused]] const auto _ : views::iota(0U, M)) {
        cin >> x;
        ++count_R[--x];
    }
    for (unsigned x; [[maybe_unused]] const auto _ : views::iota(0U, M)) {
        cin >> x;
        ++count_B[--x];
    }

    // f[S] := S に含まれる部品についている赤い端子の個数の合計 R が青い端子の個数の合計 B と等しいとき R! x^|S|、そうでないとき 0
    vector<array<modint, max_N + 1>> f(1U << N);
    // 赤い端子の個数の合計と青い端子の個数の合計が等しいような頂点集合と、それぞれの赤い端子の個数の合計
    vector<pair<unsigned, unsigned>> balanced_subgraph;
    for (const auto subgraph : views::iota(0U, 1U << N)) {
        fill(begin(f[subgraph]), end(f[subgraph]), modint{});
        unsigned B{}, R{};
        for (const auto vertex : views::iota(0U, N) | views::filter([subgraph](auto v) { return 1 & (subgraph >> v); })) {
            B += count_B[vertex];
            R += count_R[vertex];
        }
        if (B == R) {
            balanced_subgraph.emplace_back(subgraph, B);
            f[subgraph][popcount(subgraph)] += precalc.fact[B];
        }
    }

    // 高速ゼータ変換
    // Θ(2^N N^2) 時間
    for (const auto vertex : views::iota(0U, N))
        for (const auto subgraph : views::iota(0U, 1U << N) | views::filter([vertex](auto s) { return 1 & (s >> vertex); }))
            for (const auto d : views::iota(0U, N + 1))
                f[subgraph][d] += f[subgraph ^ (1U << vertex)][d];

    // 各点で log をとる
    // Θ(2^N N^2) 時間
    array<modint, max_N + 1> poly_tmp; // 途中の計算結果を入れておく多項式
    for (const auto subgraph : views::iota(0U, 1U << N)) {
        poly_tmp[0] = 0;
        for (const auto d : views::iota(1U, N + 1)) {
            poly_tmp[d] = d * f[subgraph][d];
            for (const auto x : views::iota(1U, d))
                poly_tmp[d] -= x * poly_tmp[x] * f[subgraph][d - x];
            (poly_tmp[d] *= precalc.fact[d - 1]) *= precalc.ifact[d];
        }
        swap(f[subgraph], poly_tmp);
    }

    // 高速メビウス変換
    // Θ(2^N N^2) 時間
    for (const auto vertex : views::iota(0U, N))
        for (const auto subgraph : views::iota(0U, 1U << N) | views::filter([vertex](auto s) { return 1 & (s >> vertex); }))
            for (const auto d : views::iota(0U, N + 1))
                f[subgraph][d] -= f[subgraph ^ (1U << vertex)][d];

    // (M - S(R))! 倍して和を取ると答え
    modint ans{};
    for (const auto &[subgraph, counter] : balanced_subgraph)
        ans += f[subgraph][popcount(subgraph)] * precalc.fact[M - counter];

    cout << (ans * precalc.ifact[M]).val() << endl;

    return 0;
}

posted:
last update: