Official

I - Count Arrays Editorial by en_translator


Consider a graph with vertices \(1,2,\dots,N\) and an edge \(A_i\rightarrow i\) for each \(i\ (1\leq i\leq N)\). What we want to find is the number of ways to write an integer between \(1\) and \(M\) on each vertex so that (number on \(i\)) \(\geq\) (number on \(j\)) holds for every edge \(i\rightarrow j\). Each connected component consists of a single cycle together with rooted trees whose roots are the vertices on the cycle. (Such a graph is called a functional graph, or informally a Namori graph (named after a Japanese cartoonist).)
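
To make the restated condition concrete, here is a minimal brute-force sketch that counts the valid assignments by enumerating all \(M^N\) of them. It is only meant for checking tiny inputs and is not part of the intended solution; the function name `count_arrays_brute` and the 0-indexed array `a` are assumptions made for this illustration.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not from the editorial): counts assignments x[0..n-1]
// with 1 <= x[i] <= m and x[i] <= x[a[i]] for every i, by full enumeration.
// a[] is 0-indexed here. Runs in O(m^n * n), so use it only for tiny inputs.
long long count_arrays_brute(int m, const std::vector<int>& a) {
    assert(m >= 1);
    const int n = static_cast<int>(a.size());
    std::vector<int> x(n, 1);
    long long count = 0;
    while (true) {
        // Check the constraint on every edge a[i] -> i.
        bool ok = true;
        for (int i = 0; i < n; i++) {
            if (x[i] > x[a[i]]) {
                ok = false;
                break;
            }
        }
        if (ok) count++;
        // Advance x like an odometer over {1, ..., m}^n.
        int pos = 0;
        while (pos < n && x[pos] == m) {
            x[pos] = 1;
            pos++;
        }
        if (pos == n) break;  // wrapped around: every assignment was visited
        x[pos]++;
    }
    return count;
}
```

For example, with \(N=3\), \(M=3\) and \(A=(2,1,1)\), vertices \(1\) and \(2\) form a cycle and must share a value \(v\), while vertex \(3\) may take any value up to \(v\), giving \(1+2+3=6\) assignments.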

The connected components are independent of one another (when writing numbers), so we can find the answer for each component and multiply the results. From here on, we assume that the graph is connected.

The vertices on the cycle are constrained cyclically: \(x_i \leq x_{A_i}\leq x_{A_{A_i}}\leq \dots \leq x_i\), so all the numbers on the cycle must be equal. Thus, we can contract all the vertices on the cycle into a single vertex, which turns the graph into a rooted tree.
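
Before contracting, the cycle vertices have to be identified. One possible way, sketched below, is to repeatedly peel off vertices that no remaining vertex points to; whatever survives lies on a cycle. This swaps in a different technique from the DSU used in the sample code, and the name `cycle_vertices` and the 0-indexed array `a` are assumptions made for this illustration.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Hypothetical helper (not from the editorial): returns on_cycle[v] == true
// exactly for the vertices that lie on a cycle of the functional graph
// v -> a[v] (0-indexed). Runs in O(n).
std::vector<bool> cycle_vertices(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    // indeg[v] = number of not-yet-removed vertices i with a[i] == v.
    std::vector<int> indeg(n, 0);
    for (int i = 0; i < n; i++) {
        assert(0 <= a[i] && a[i] < n);
        indeg[a[i]]++;
    }
    std::vector<bool> on_cycle(n, true);
    std::queue<int> q;
    for (int i = 0; i < n; i++) {
        if (indeg[i] == 0) q.push(i);
    }
    // A vertex that nothing points to cannot be on a cycle, so remove it
    // and update the vertex it points to.
    while (!q.empty()) {
        int v = q.front();
        q.pop();
        on_cycle[v] = false;
        if (--indeg[a[v]] == 0) q.push(a[v]);
    }
    return on_cycle;
}
```

With \(A=(2,1,1)\), that is `a = {1, 0, 0}` after shifting to 0-indexing, the first two vertices are reported as cycle vertices and the third is not.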

On this rooted tree, perform the following DP (Dynamic Programming):

  • \(\mathrm{dp}_{i, j}=\) (the number of ways to write integers onto each vertex in the subtree rooted at \(i\), with vertex \(i\) having number \(j\) written on it).

This DP can be evaluated from the leaves: \(\mathrm{dp}_{i, j}= \displaystyle \prod_{c\in C_i} \sum_{k=1}^{j} \mathrm{dp}_{c, k}\), where \(C_i\) is the set of children of vertex \(i\) (for a leaf, the empty product gives \(\mathrm{dp}_{i, j}=1\)). Evaluating it naively costs a total of \(O(NM^2)\) time, but it can easily be optimized to \(O(NM)\) by maintaining cumulative sums over \(k\).
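
The recurrence with cumulative sums can be sketched on a plain rooted tree as follows. This is an illustrative sketch rather than the reference implementation: it uses ordinary 64-bit arithmetic modulo \(998244353\) instead of a modint type, and the names `subtree_dp`, `count_tree` and `children` are assumptions made for this example.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// Hypothetical helper (not from the editorial): returns the row dp[v][0..m-1],
// where index j stands for the value j + 1 written on v.
std::vector<long long> subtree_dp(int v, int m,
                                  const std::vector<std::vector<int>>& children) {
    std::vector<long long> dp(m, 1);  // empty product for a leaf
    for (int c : children[v]) {
        std::vector<long long> child = subtree_dp(c, m, children);
        long long prefix = 0;  // child[0] + ... + child[j]
        for (int j = 0; j < m; j++) {
            prefix = (prefix + child[j]) % MOD;
            dp[j] = dp[j] * prefix % MOD;
        }
    }
    return dp;
}

// Number of valid assignments on the whole tree: sum of the root's row.
long long count_tree(int root, int m,
                     const std::vector<std::vector<int>>& children) {
    assert(m >= 1);
    long long total = 0;
    for (long long ways : subtree_dp(root, m, children)) {
        total = (total + ways) % MOD;
    }
    return total;
}
```

For a root with a single leaf child and \(M=3\), the root's row is \((1,2,3)\), so the total is \(6\). Each child contributes one pass of length \(M\), which is where the \(O(NM)\) bound comes from.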

Sample code (C++):

#include <bits/stdc++.h>
#include <atcoder/modint>
#include <atcoder/dsu>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

int main() {
    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    vector<vector<int>> ch(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        --a[i];
        ch[a[i]].push_back(i);
    }
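    // Find the cycle of each connected component with a DSU: the edge
    // (i, a[i]) whose endpoints are already connected closes the cycle.
    vector<bool> in_cycle(n);
    vector<vector<int>> cycles;
    dsu uf(n);
    for (int i = 0; i < n; i++) {
        if (uf.same(i, a[i])) {
            // Follow a[] from a[i] until we return to it, collecting the cycle.
            cycles.emplace_back();
            int now = a[i];
            do {
                in_cycle[now] = true;
                cycles.back().push_back(now);
                now = a[now];
            } while (now != a[i]);
        } else {
            uf.merge(i, a[i]);
        }
    }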
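    // dp[i][k] = number of ways for the subtree of i with (k + 1) written on i
    vector dp(n, vector<mint>(m, 1));
    auto dfs = [&](auto &dfs, int i) -> void {
        for (int j: ch[i]) {
            if (in_cycle[j]) continue;  // do not walk along the cycle itself
            dfs(dfs, j);
            mint sum = 0;  // cumulative sum dp[j][0] + ... + dp[j][k]
            for (int k = 0; k < m; k++) {
                sum += dp[j][k];
                dp[i][k] *= sum;
            }
        }
    };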
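    mint ans = 1;
    for (const auto &cycle: cycles) {
        // prod[j] = ways for this component when every cycle vertex holds (j + 1)
        vector<mint> prod(m, 1);
        for (int i: cycle) {
            dfs(dfs, i);
            for (int j = 0; j < m; j++) prod[j] *= dp[i][j];
        }
        // Sum over the common value on the cycle, then multiply across components.
        mint sum = 0;
        for (int j = 0; j < m; j++) sum += prod[j];
        ans *= sum;
    }
    cout << ans.val() << endl;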
}
