F - Many Increasing Problems 解説 by potato167


公式解説 にもあるように、全ての \(A\)\(1\leq s \leq M\) に対して、以下の答えの総和を求めれば良いです。

  • \(A\) の要素のうち \(s\) 以下のものを \(0\)\(s+1\) 以上のものを \(1\) とした数列 \(B\) に対する Increasing Problem の問題の答えを求めてください。

この問題の答えは以下と同値です。

  • \(x = 0\) で初期化する。
  • \(i = 1, \dots, N\) の順に、\(B_{i}\)\(0\) なら \(x\)\(1\) 減らし、\(1\) なら \(1\) 増やす。
  • \(N+\) (\(x\) の最小値) \(-\sum B\) が答え

全ての \(A, s\) に対する \(\sum B\) の総和は \(\dfrac{M^{N}N(M - 1)}{2}\) であることから、全ての \(A, s\) に対する \(N +\) (\(x\) の最小値) が求まれば良いです。

\(\sum B = a\) である全ての長さ \(N\)\(01\) 数列に対する \(N +\) (\(x\) の最小値) の総和を \(f(a)\) とします。この \(f(a)\) を全ての \(a\) に対して求めます。

\(N +\) (\(x\) の最小値) の総和は ((\(x\) の最小値が \(0\) 以上であるものの場合の数) \(+\) (\(x\) の最小値が \(-1\) 以上であるものの場合の数) \(+\cdots +\) (\(x\) の最小値が \(-N + 1\) 以上であるものの場合の数) ) と同じです。よって、鏡像法によって、\(f(a)\) は以下のように表されます。

\[f(a) = \sum_{i = \max(0, N - 2 a)}^{N - 1} \left( \begin{matrix} N\\ a\\ \end{matrix} \right) - \left( \begin{matrix} N\\ a + i + 1\\ \end{matrix} \right) \]

\(\sum B = a\) であるとき、そのような \(B\) を達成するような \(A, s\) の組み合わせは、\(\sum s ^{N - a} (M - s)^{a}\) 通りあります。よって、以下のコードで答えを求めることができます。

#include <bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for (int i=(int)(a);i<(int)(b);i++)
#include <atcoder/modint>
using mint = atcoder::modint998244353;

int main() {
    // input
    int N, M;
    cin >> N >> M;

    // Binomial
    vector<mint> fact(N + 1, 1), fact_inv(N + 1);
    rep(i, 1, N + 1) fact[i] = fact[i - 1] * i;
    fact_inv[N] = fact[N].inv();
    for (int i = N; i > 0; i--) fact_inv[i - 1] = fact_inv[i] * i;
    auto C = [&](int a, int b) -> mint {
        if (a < b || b < 0) return 0;
        return fact[a] * fact_inv[b] * fact_inv[a - b];
    };

    // init
    mint ans = ((mint)(M)).pow(N) * N * (M - 1) / 2;
    ans *= -1;
    
    // calc f
    vector<mint> f(N + 1);
    rep(a, 0, N + 1){
        rep(i, max(0, N - 2 * a), N){
            f[a] += C(N, a) - C(N, a + i + 1);
        }
    }

    // calc ans
    rep(a, 0, N + 1){
        mint tmp = 0;
        rep(s, 1, M + 1){
            tmp += ((mint)(s)).pow(N - a) * ((mint)(M - s)).pow(a);
        }
        ans += tmp * f[a];
    }

    // output
    cout << ans.val() << "\n";
}

このコードを高速化することで、 AC を得ることができます。

まず、\(f\) を求める場面について、過去の ARC の問題と同様に、差分更新することで、 \(O(N)\) で求まります。\(\sum_{i = \max(0, N - 2 a)}^{N - 1} \left( \begin{matrix} N\\ a + i + 1\\ \end{matrix} \right) \) の部分の差分更新が可能です。

全ての \(a\) に対して \(\sum s ^{N - a} (M - s)^{a}\) を求める場面は、以下の式が成り立つことを用います。

\[\sum s ^{N - a} (M - s)^{a} = [x^{a}]\left(\sum\dfrac{s^{N}}{1-\frac{M-s}{s}x}\right)\]

右辺の多項式は、分割統治を用いて \(O(M\log(M)^{2} + N\log(N))\) で求まるため、全ての \(a\) に対する \(\sum s ^{N - a} (M - s)^{a}\) も同様の計算量で求まるため、 AC することができます。

c++ による実装例 218ms

投稿日時:
最終更新: