Official

H - Count Multiset Editorial by penguinman


サイズ \(k\) の多重集合 \(A\) の個数を数える代わりに、長さ \(k\) の単調非減少正整数列 \(A'=(A'_1,A'_2,\ldots,A'_k)\) の個数を数えることを考えます。

さらにこのままでは見通しが悪いので、差分を取った数列、すなわち以下の漸化式で表される長さ \(k\) の数列 \(B=(B_1,B_2,\ldots,B_k)\) の個数を数えることを考えます。

  • \(B_1=A'_1\)
  • \(B_i=A'_i-A'_{i-1}\ (2 \leq i \leq k)\)

するとこの問題は以下のように言い換えることが可能です。

\(k=1,2,\ldots,N\) について、以下の \(3\) つの条件をすべて満たす非負整数列 \(B=(B_1,B_2,\ldots,B_k)\) の個数を数えよ。

  • \(\sum_{i=1}^{k} B_i \times (k-i-1) = N\)
  • \(0\)\(M\) 個以上連続する箇所が存在しない。すなわち \(B_i=B_{i+1}=\cdots=B_{i+M-1}=0\) となるような整数 \(i\) が存在しない。
  • \(B_1\) は正

ここで \(B\) の添字を反転させます。すると上記の言い換えは以下のようになります。

\(k=1,2,\ldots,N\) について、以下の \(3\) つの条件をすべて満たす非負整数列 \(B=(B_1,B_2,\ldots,B_k)\) の個数を数えよ。

  • \(\sum_{i=1}^{k} B_i \times i = N\)
  • \(0\)\(M\) 個以上連続する箇所が存在しない。すなわち \(B_i=B_{i+1}=\cdots=B_{i+M-1}=0\) となるような整数 \(i\) が存在しない。
  • \(B_k\) は正

これでかなり見通しが良くなりました。

\(0 \leq x,y \leq N\) を満たす整数対 \((x,y)\) について、\(f(x,y)\) を以下のように定めます。

  • 以下の \(3\) つの条件をすべて満たす非負整数列 \(B=(B_1,B_2,\ldots,B_x)\) の個数
    • \(\sum_{i=1}^{x} B_i \times i = y\)
    • \(0\)\(M\) 個以上連続する箇所が存在しない。すなわち \(B_i=B_{i+1}=\cdots=B_{i+M-1}=0\) となるような整数 \(i\) が存在しない。
    • \(B_x\) は正

\(f(x,y)\)\(p \leq x\) および \(q \leq y\) を満たす \(f(p,q)\) の集合を用いて再帰的に計算することが可能であるため、動的計画法による求値が可能であり、愚直に求めた場合の計算量は \(O(N^3M)\) となります。しかしここで累積和等を用いて工夫しながらの計算をすると、その計算量を \(O(N^2)\) まで落とすことが可能です。

よってこの問題を解くことができました。

実装例 (C++)

#include<bits/stdc++.h>
using namespace std;

const int mod = 998244353;

int main(){
    int N,M; cin >> N >> M;
    vector<vector<int>> dp(N+1,vector<int>(N+1));
    dp[0][0] = 1;
    vector<int> sum(N+1);
    sum[0] = 1;
    for(int i=1; i<=N; i++){
        for(int j=i; j<=N; j++){
            dp[i][j] = sum[j-i]+dp[i][j-i];
            dp[i][j] %= mod;
        }
        for(int j=0; j<=N; j++){
            sum[j] += dp[i][j];
            sum[j] %= mod;
            if(M <= i){
                sum[j] -= dp[i-M][j];
                sum[j] %= mod;
            }
        }
    }
    for(int i=1; i<=N; i++){
        int ans = dp[i][N];
        if(ans < 0) ans += mod;
        cout << ans << endl;
    }
}

posted:
last update: