H - Count Multiset Editorial by en_translator (Official)


Instead of counting the multisets \(A\) of size \(k\) directly, we count the non-decreasing integer sequences \(A'=(A'_1,A'_2,\ldots,A'_k)\) of length \(k\). Sorting a multiset yields exactly one such sequence, so the two counts are equal.

To make the structure clearer, we count the corresponding difference sequences instead, that is, the sequences \(B=(B_1,B_2,\ldots,B_k)\) of length \(k\) defined as follows.

  • \(B_1=A'_1\)
  • \(B_i=A'_i-A'_{i-1}\ (2 \leq i \leq k)\)

Then the problem can be rephrased as follows.

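For each \(k=1,2,\ldots,N\), count the sequences \(B=(B_1,B_2,\ldots,B_k)\) of non-negative integers of length \(k\) satisfying the following three conditions:

  • \(\sum_{i=1}^{k} B_i \times (k-i+1) = N\)
  • There are no \(M\) or more consecutive \(0\)’s; i.e. there does not exist an integer \(i\) such that \(B_i=B_{i+1}=\cdots=B_{i+M-1}=0\)
  • \(B_1\) is positive

The first condition holds because \(A'_j = B_1 + \cdots + B_j\), so \(B_i\) contributes to the \(k-i+1\) terms \(A'_i,\ldots,A'_k\) of the sum. The third condition says \(A'_1 = B_1 \geq 1\), i.e. every element is positive.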

Now reverse the order of \(B\), that is, replace \((B_1,\ldots,B_k)\) with \((B_k,\ldots,B_1)\). The restatement above is then equivalent to the following.

For each \(k=1,2,\ldots,N\), count the sequences \(B=(B_1,B_2,\ldots,B_k)\) of non-negative integers of length \(k\) satisfying the following three conditions:

  • \(\sum_{i=1}^{k} B_i \times i = N\)
  • There are no \(M\) or more consecutive \(0\)’s; i.e. there does not exist an integer \(i\) such that \(B_i=B_{i+1}=\cdots=B_{i+M-1}=0\)
  • \(B_k\) is positive
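The reversal only relabels the weights. Writing \(C_j = B_{k+1-j}\) for the reversed sequence (the name \(C\) is introduced here only for this derivation) and substituting \(j = k+1-i\):

\[
\sum_{i=1}^{k} B_i \times (k-i+1) \;=\; \sum_{j=1}^{k} B_{k+1-j} \times j \;=\; \sum_{j=1}^{k} C_j \times j .
\]

The positive first element \(B_1\) becomes the positive last element \(C_k\), and the condition on runs of zeros is unaffected by reversing the order.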

This form is straightforward to handle with DP.

For a pair of integers \((x,y)\) with \(0 \leq x,y \leq N\), define \(f(x,y)\) as follows.

  • The number of sequences \(B=(B_1,B_2,\ldots,B_x)\) of non-negative integers of length \(x\) satisfying the following three conditions:
    • \(\sum_{i=1}^{x} B_i \times i = y\)
    • There are no \(M\) or more consecutive \(0\)’s; i.e. there does not exist an integer \(i\) such that \(B_i=B_{i+1}=\cdots=B_{i+M-1}=0\)
    • \(B_x\) is positive

By convention, \(f(0,0)=1\) (the empty sequence) and \(f(0,y)=0\) for \(y>0\). The answer for each \(k\) is \(f(k,N)\).

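\(f(x,y)\) depends only on values \(f(p,q)\) with \(p \leq x\) and \(q < y\), so all values can be found by DP (dynamic programming). Processed naively, by enumerating the value of \(B_x\) and the index \(p\) of the previous positive element (\(\max(0,x-M) \leq p \leq x-1\), where \(p=0\) means there is none), this takes \(O(N^3M)\) time. Two observations bring this down to \(O(N^2)\). If \(B_x \geq 2\), decreasing \(B_x\) by one gives a sequence counted by \(f(x,y-x)\). If \(B_x = 1\), the count is \(\sum_{p=\max(0,x-M)}^{x-1} f(p,y-x)\), a window sum over \(p\) that can be maintained with cumulative sums as \(x\) increases.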
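In formula form, splitting on \(B_x = 1\) versus \(B_x \geq 2\) (in the latter case, decreasing \(B_x\) by one gives a sequence counted by \(f(x,y-x)\)) yields, for \(1 \leq x \leq y\), with \(f(x,y)=0\) when \(x \geq 1\) and \(y<x\):

\[
f(x,y) = f(x,\,y-x) + S(x,\,y-x), \qquad S(x,q) = \sum_{p=\max(0,\,x-M)}^{x-1} f(p,q),
\]

where \(S\) is a shorthand introduced here for the window sum. Moving from \(x\) to \(x+1\) shifts the window by one position,

\[
S(x+1,q) = S(x,q) + f(x,q) - f(x-M,q),
\]

where the last term is taken as \(0\) when \(x < M\). Each state therefore costs \(O(1)\) and the whole table costs \(O(N^2)\). This matches the roles of `dp` and `sum` in the sample code.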

This solves the problem in \(O(N^2)\) time.

Sample code (C++)

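#include<bits/stdc++.h>
using namespace std;

const int mod = 998244353;

int main(){
    int N,M; cin >> N >> M;
    // dp[i][j] = f(i,j): sequences B_1..B_i with sum of B_t * t equal to j,
    // B_i positive, and no M consecutive zeros. dp[0][0] = 1 (empty sequence).
    vector<vector<int>> dp(N+1,vector<int>(N+1));
    dp[0][0] = 1;
    // Before row i is computed, sum[j] = dp[max(0,i-M)][j] + ... + dp[i-1][j],
    // i.e. the possible positions of the previous positive element.
    vector<int> sum(N+1);
    sum[0] = 1;
    for(int i=1; i<=N; i++){
        for(int j=i; j<=N; j++){
            // B_i = 1 contributes sum[j-i]; B_i >= 2 reduces to dp[i][j-i].
            dp[i][j] = (sum[j-i] + dp[i][j-i]) % mod;
        }
        // Slide the window: add row i and, once i >= M, drop row i-M.
        for(int j=0; j<=N; j++){
            sum[j] = (sum[j] + dp[i][j]) % mod;
            if(M <= i){
                // Adding mod first keeps the value non-negative.
                sum[j] = (sum[j] - dp[i-M][j] + mod) % mod;
            }
        }
    }
    // The answer for k is f(k,N).
    for(int k=1; k<=N; k++){
        cout << dp[k][N] << '\n';
    }
}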
