公式

I - Count Arrays 解説 by yuto1115

解説

頂点が \(1,2,\dots,N\) であり、各 \(i\ (1\leq i\leq N)\) について \(A_i\rightarrow i\) の辺を張ってできるグラフについて考えます。求めたいのは、このグラフ上の各頂点に \(1\) 以上 \(M\) 以下の整数を書き込む方法であって、\(i\rightarrow j\) の辺があるならば (\(i\) に書き込まれた値)\(\geq\)\(j\) に書き込まれた値)を満たすようなものの数です。このグラフの各連結成分は、\(1\) つのサイクルとそのサイクル上の各頂点を根とした根付き木からなります。(このようなグラフは functional graph や俗になもりグラフなどと呼ばれます。)

グラフの各連結成分は(条件を満たすように値を書き込む上で)互いに干渉しないので、各連結成分について答えを求め、それらをかけ合わせれば良いです。以下、グラフが連結であることを仮定します。

サイクル上に含まれる頂点については \(x_i \leq x_{A_i}\leq x_{A_{A_i}}\leq \dots \leq x_i\) のような循環した大小制約が生じるため、結局同じサイクル上の頂点には全て同じ値を書き込む必要があります。ゆえに、サイクル上の全ての頂点を \(1\) つの頂点にまとめて縮約することができ、このときグラフは \(1\) つの根付き木になります。

この根付き木において、以下のような DP を行います:

  • \(\mathrm{dp}_{i, j}=\)\(i\) を根とする部分木内の各頂点に条件を満たすように値を書き込む方法であって、\(i\) に書き込まれる値が \(j\) であるようなものの数)

この DP は葉から順番に計算できます。頂点 \(i\) の子供の集合を \(C_i\) とすると、\(\mathrm{dp}_{i, j}= \displaystyle \prod_{c\in C_i} \sum_{k=1}^{j} \mathrm{dp}_{c, k}\) です。この式をそのまま用いて愚直に計算すると全体で \(O(N^3)\) になりますが、累積和を用いることで簡単に \(O(N^2)\) に改善することができます。

実装例 (C++) :

#include <bits/stdc++.h>
#include <atcoder/modint>
#include <atcoder/dsu>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

int main() {
    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    vector<vector<int>> ch(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        --a[i];
        ch[a[i]].push_back(i);
    }
    vector<bool> in_cycle(n);
    vector<vector<int>> cycles;
    dsu uf(n);
    for (int i = 0; i < n; i++) {
        if (uf.same(i, a[i])) {
            cycles.emplace_back();
            int now = a[i];
            do {
                in_cycle[now] = true;
                cycles.back().push_back(now);
                now = a[now];
            } while (now != a[i]);
        } else {
            uf.merge(i, a[i]);
        }
    }
    vector dp(n, vector<mint>(m, 1));
    auto dfs = [&](auto &dfs, int i) -> void {
        for (int j: ch[i]) {
            if (in_cycle[j]) continue;
            dfs(dfs, j);
            mint sum = 0;
            for (int k = 0; k < m; k++) {
                sum += dp[j][k];
                dp[i][k] *= sum;
            }
        }
    };
    mint ans = 1;
    for (auto cycle: cycles) {
        vector<mint> prod(m, 1);
        for (int i: cycle) {
            dfs(dfs, i);
            for (int j = 0; j < m; j++) prod[j] *= dp[i][j];
        }
        mint sum = 0;
        for (int j = 0; j < m; j++) sum += prod[j];
        ans *= sum;
    }
    cout << ans.val() << endl;
}

投稿日時:
最終更新: