G - Count Cycles Editorial by en_translator (Official)


For any \(i\) and \(j\) with \(i\neq j\), let \(C_{i,j}\) denote the number of edges between vertices \(i\) and \(j\).

First, we handle cycles of size \(2\) separately. Any two of the \(C_{i,j}\) parallel edges between \(i\) and \(j\) form such a cycle, so it suffices to add \(\binom{C_{i,j}}{2}\) to the answer for each pair \(i<j\). From now on, we only consider cycles of size \(3\) or greater.
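The size-2 count can be sketched as a standalone function. It assumes `cnt` is the symmetric \(N\times N\) matrix of multiplicities \(C_{i,j}\) (0-indexed), and the name `countTwoCycles` is hypothetical. It uses plain 64-bit integers without the modulus, so it is only meant for small hand checks.

```cpp
#include <vector>

// Hypothetical helper: number of size-2 cycles in a multigraph.
// cnt[i][j] = C_{i,j}, the number of parallel edges between i and j.
// Any 2 of those C_{i,j} edges form one cycle, i.e. binom(C_{i,j}, 2).
long long countTwoCycles(const std::vector<std::vector<int>>& cnt) {
    long long total = 0;
    const int n = static_cast<int>(cnt.size());
    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {  // each unordered pair once
            const long long c = cnt[i][j];
            total += c * (c - 1) / 2;  // binom(c, 2)
        }
    }
    return total;
}
```

For example, if vertices 1 and 2 are joined by three edges and vertices 2 and 3 by two, this gives \(3+1=4\). The sample code below adds \(C_{i,j}(C_{i,j}-1)\) instead and relies on its final division by \(2\).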

Fix the largest-indexed vertex of the cycle to be \(s\ (3\leq s\leq N)\). This vertex is at least \(3\) because the cycle has at least three distinct vertices. We count the ways to traverse such a cycle starting from vertex \(s\) using bit DP (dynamic programming over subsets). Specifically, define

  • \(\mathrm{dp}[S][i]:=\) the number of paths that start at vertex \(s\), visit each vertex of \(S\) exactly once, and end at vertex \(i\) (\(s,i\in S\subseteq \{1,2,\dots,s\}\)). Paths that use different parallel edges are counted as different paths.

Initialize with \(\mathrm{dp}[\{s\}][s]=1\). For each \(i\in S\) and each \(j\in\{1,2,\dots,s\}\setminus S\), the transition is

\[\mathrm{dp}[S\cup \{j\}][j] \leftarrow \mathrm{dp}[S\cup \{j\}][j]+\mathrm{dp}[S][i]\times C_{i,j}.\]
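A minimal sketch of this DP for one fixed \(s\) is shown below. It assumes 0-indexed vertices \(0,\dots,s-1\), so vertex \(s-1\) plays the role of \(s\), and stores unreduced 64-bit counts instead of working modulo a prime. The function name `pathTable` is hypothetical.

```cpp
#include <vector>

// Hypothetical sketch of the bit DP for one fixed s (0-indexed vertices
// 0..s-1, with s-1 as the largest vertex). No modulus: small inputs only.
// dp[S][i] = number of paths that start at s-1, visit exactly the vertices
// of bitmask S, and end at i (parallel edges counted separately).
std::vector<std::vector<unsigned long long>> pathTable(
        int s, const std::vector<std::vector<int>>& cnt) {
    std::vector<std::vector<unsigned long long>> dp(
        1u << s, std::vector<unsigned long long>(s, 0));
    const int top = s - 1;
    dp[1u << top][top] = 1;  // dp[{s}][s] = 1
    // Masks in [2^top, 2^s) are exactly those containing `top`. Increasing
    // order finalizes dp[S] before it is extended to larger masks.
    for (unsigned S = 1u << top; S < (1u << s); S++) {
        for (int i = 0; i < s; i++) {
            if (!(S >> i & 1) || dp[S][i] == 0) continue;  // i must be in S
            for (int j = 0; j < s; j++) {
                if (S >> j & 1) continue;  // j must not be in S
                dp[S | 1u << j][j] += dp[S][i] * cnt[i][j];
            }
        }
    }
    return dp;
}
```

As a check, consider a triangle on vertices \(1,2,3\) in which \(1\) and \(2\) are joined by two parallel edges. Take \(s=3\). Then \(\mathrm{dp}[\{1,2,3\}][1]=C_{3,2}\,C_{2,1}=2\), which is entry `dp[0b111][0]` in the 0-indexed table.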

Closing each path with an edge back to \(s\), the number of cycles of size \(3\) or greater whose largest-indexed vertex is \(s\) is

\[\displaystyle \frac{1}{2} \sum_{|S| \geq 3} \sum_{i\in S} \mathrm{dp}[S][i]\times C_{i,s}.\]

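Starting from vertex \(s\), every cycle of size \(3\) or greater can be traversed in exactly two ways, one in each direction. The sum therefore counts each cycle twice, which is why the coefficient \(\frac{1}{2}\) is needed. The condition \(|S|\geq 3\) excludes size-\(2\) cycles, which were already counted.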
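The factor \(\frac{1}{2}\) can be checked by brute force on a tiny graph. The sketch below replaces the editorial's DP with a plain depth-first search that enumerates every closed traversal starting from \(s\). Each traversal is weighted by the product of the edge multiplicities it uses. Vertices are 0-indexed, and the names `Matrix`, `countTraversals` and `cyclesWithLargest` are hypothetical.

```cpp
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Hypothetical brute-force helper (tiny graphs only): counts traversals
// top -> v1 -> ... -> vk -> top with k >= 2, where v1..vk are distinct
// vertices smaller than top, weighted by the product of edge multiplicities.
// `len` is the number of vertices on the current path, including top.
unsigned long long countTraversals(const Matrix& cnt, int top, int cur,
                                   unsigned visited, int len,
                                   unsigned long long ways) {
    unsigned long long total = 0;
    if (len >= 3) total += ways * cnt[cur][top];  // close the cycle at top
    for (int j = 0; j < top; j++) {
        if ((visited >> j & 1) || cnt[cur][j] == 0) continue;
        total += countTraversals(cnt, top, j, visited | 1u << j, len + 1,
                                 ways * cnt[cur][j]);
    }
    return total;
}

// Cycles of size >= 3 whose largest vertex is s (1-indexed), i.e. vertex
// s-1 when 0-indexed. Each cycle is traversed once in each direction.
unsigned long long cyclesWithLargest(const Matrix& cnt, int s) {
    const int top = s - 1;
    return countTraversals(cnt, top, top, 1u << top, 1, 1) / 2;
}
```

On the complete graph \(K_4\) with \(s=4\), the search finds \(12\) traversals. Six come from the three triangles through \(s\) and six from the three \(4\)-cycles, so there are \(6\) cycles.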

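Perform this for every \(s\) and add up the results. For a fixed \(s\), there are \(O(2^s)\) subsets, each with \(O(s^2)\) transitions. Including \(O(M)\) for reading the edges, the total time complexity is \(O(M) + \displaystyle \sum_{s=3}^{N} O(2^ss^2)=O(M+2^NN^2)\), which is fast enough.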
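To put a rough number on this bound, the sketch below adds up an upper bound on the \((i, j)\) pairs examined by the loops of the sample code. For each \(s\), only the \(2^{s-1}\) masks that contain vertex \(s-1\) are visited. Each such mask is charged \(s\cdot s\) pairs, even though the \(j\) loop actually runs only for \(i\) in the mask. The name `dpPairBound` is hypothetical, and \(N=20\) below is only an illustrative value, not a constraint taken from this section.

```cpp
#include <cstdint>

// Hypothetical estimate: upper bound on the (i, j) pairs visited by the
// sample code's DP loops. For each s, 2^(s-1) masks contain vertex s-1,
// and each mask pairs at most s values of i with s values of j.
std::uint64_t dpPairBound(int n) {
    std::uint64_t total = 0;
    for (int s = 3; s <= n; s++) {
        total += (std::uint64_t{1} << (s - 1)) * static_cast<std::uint64_t>(s) * s;
    }
    return total;
}
```

For \(N=20\) this gives \(380{,}633{,}076\), which is below \(2^{20}\cdot 20^2 = 419{,}430{,}400\). This matches the \(O(2^NN^2)\) bound, since \(\sum_{s\le N} 2^{s-1}s^2 \le N^2\,2^N\).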

Sample code (C++):

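```cpp
#include <bits/stdc++.h>
#include <atcoder/modint>

using namespace std;

using mint = atcoder::modint998244353;

int main() {
    int n, m;
    cin >> n >> m;
    // cnt[u][v] = C_{u,v}: number of edges between u and v (0-indexed)
    vector cnt(n, vector<int>(n));
    for (int i = 0; i < m; i++) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        ++cnt[u][v];
        ++cnt[v][u];
    }

    mint ans = 0;
    // Size-2 cycles: add C * (C - 1) = 2 * binom(C, 2); the final
    // division by 2 turns this into binom(C, 2).
    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {
            ans += mint(cnt[i][j]) * (cnt[i][j] - 1);
        }
    }
    // Cycles of size >= 3 whose largest vertex is s (index s - 1)
    for (int s = 3; s <= n; s++) {
        // dp[bit][i]: paths from s - 1 visiting exactly the vertices in bit, ending at i
        vector dp(1 << s, vector<mint>(s));
        dp[1 << (s - 1)][s - 1] = 1;
        // Only masks containing s - 1 are reachable; increasing order
        // finalizes dp[bit] before it is extended.
        for (int bit = 1 << (s - 1); bit < (1 << s); bit++) {
            int ppc = popcount((unsigned) bit);  // std::popcount (C++20)
            for (int i = 0; i < s; i++) {
                if (~bit >> i & 1) continue;  // i must be in bit
                if (ppc >= 3) {
                    // close the cycle with an edge between i and s - 1
                    ans += dp[bit][i] * cnt[i][s - 1];
                }
                for (int j = 0; j < s; j++) {
                    if (bit >> j & 1) continue;  // j must not be visited yet
                    dp[bit | 1 << j][j] += dp[bit][i] * cnt[i][j];
                }
            }
        }
    }
    // Each cycle of size >= 3 was counted once per direction, and the
    // size-2 term was doubled, so halve everything at once.
    ans /= 2;
    cout << ans.val() << endl;
}
```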
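The sample code depends on the AtCoder Library's `modint` and reads from standard input. The sketch below wraps the same algorithm in a plain function so it can be checked on small graphs without that dependency. It does the arithmetic modulo \(998244353\) by hand and applies the factor \(\frac{1}{2}\) by multiplying by the modular inverse of \(2\), which is \((998244353+1)/2\). The name `countCycles` and its edge-list interface (1-indexed endpoints) are hypothetical.

```cpp
#include <bitset>
#include <utility>
#include <vector>

// Hypothetical ACL-free sketch of the sample solution.
// `edges` holds 1-indexed endpoints (u, v), as in the input.
long long countCycles(int n, const std::vector<std::pair<int, int>>& edges) {
    const long long MOD = 998244353;
    const long long INV2 = (MOD + 1) / 2;  // 2 * INV2 == 1 (mod MOD)
    std::vector<std::vector<long long>> cnt(n, std::vector<long long>(n, 0));
    for (const auto& e : edges) {
        const int u = e.first - 1, v = e.second - 1;
        ++cnt[u][v];
        ++cnt[v][u];
    }
    long long ans = 0;
    // Size-2 cycles, doubled: C * (C - 1) = 2 * binom(C, 2).
    for (int i = 0; i < n; i++)
        for (int j = i + 1; j < n; j++)
            ans = (ans + cnt[i][j] * (cnt[i][j] - 1)) % MOD;
    for (int s = 3; s <= n; s++) {
        const int top = s - 1;  // largest vertex, 0-indexed
        std::vector<std::vector<long long>> dp(1 << s, std::vector<long long>(s, 0));
        dp[1 << top][top] = 1;
        for (int bit = 1 << top; bit < (1 << s); bit++) {
            const int ppc = static_cast<int>(std::bitset<32>(bit).count());
            for (int i = 0; i < s; i++) {
                if (!(bit >> i & 1) || dp[bit][i] == 0) continue;
                if (ppc >= 3)  // close the cycle back to the largest vertex
                    ans = (ans + dp[bit][i] * cnt[i][top]) % MOD;
                for (int j = 0; j < s; j++) {
                    if (bit >> j & 1) continue;  // j already visited
                    const int nxt = bit | 1 << j;
                    dp[nxt][j] = (dp[nxt][j] + dp[bit][i] * cnt[i][j]) % MOD;
                }
            }
        }
    }
    // Every counted cycle appears exactly twice, so halve modulo MOD.
    return ans * INV2 % MOD;
}
```

Two small checks:

- The complete graph on \(4\) vertices gives \(7\): four triangles and three \(4\)-cycles.
- A triangle in which vertices \(1\) and \(2\) are joined by two parallel edges gives \(3\): one size-\(2\) cycle and two triangles.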
