D - Between Two Arrays Editorial by Nyaan
まずはナイーブな DP を考えてみましょう。 DP テーブル \(dp_{i,j}\) を
- \(dp_{i,j} := \) \(i\) 番目まで要素が確定していて、\(c_i = j\) である場合の数
とおきます。\(i\) 番目の要素に注目すると、
- \(a_i \leq j \leq b_i\) の場合: 1 つ前の要素以上であれば OK
- そうでない場合: NG
となることから、DP の遷移は次のようになります。
(便宜上「C の 0 番目の要素に 0 がある」とみなして \(dp_{0,0} = 1\) と置くことで場合分けの量を減らしています。)
\[dp_{i, j} = \begin{cases} 1 & i = 0 \wedge j = 0 \\ \sum_{0 \leq k \leq j} dp_{i-1,k} & i \geq 1 \wedge a_i \leq j \leq b_i \\ 0 & \mathrm{otherwise} \\ \end{cases} \]
この式は \(i\) を昇順に見ていくことで順番に計算することができます。 また、求める答えは \(N\) 番目まで要素を確定させたときの場合の数、すなわち \(M := \max( \max(A),\max(B))\) として
\[dp_{N,0} + dp_{N,1} + \dots + dp_{N, M}\]
になります。
よって、まずは多項式時間で答えを求めるアルゴリズムを得ることができました。上記の遷移からなる DP の計算量を考えると、
- 更新する必要がある要素の個数:\(\mathrm{O}(NM)\)
- 要素を \(1\) 個計算にかかる計算量: \(\mathrm{O}(M)\)
になるので計算量は \(\mathrm{O}(NM^2)\) とわかります。しかしこれでは TLE してしまうので、どうにかして DP の計算量を落としたいです。
そこで、遷移が \(\sum_{0 \leq k \leq j} dp_{i-1,k}\) という形をしているのに注目して 累積和 を DP で管理するという方法を取ってみましょう。 DP の累積和テーブル \(R_{i,j}\) を
- \(R_{i, j} := \sum_{0 \leq k \leq j} dp_{i,k}\)
とおきます。すると前述の DP は
\[R_{i, j} = \begin{cases} 1 & i = 0 \wedge j = 0 \\ R_{i, j - 1} + R_{i-1,j} & i \geq 1 \wedge a_i \leq j \leq b_i \\ R_{i,j-1} & \mathrm{otherwise} \\ \end{cases} \]
という形で表せるので、要素 1 個あたりの計算量を \(\mathrm{O}(1)\) に減らすことができます。この式を利用して DP することで、この問題を \(\mathrm{O}(NM)\) で解くことができました。
別解として、Fenwick Tree や Segment Tree を利用すれば累積和を \(\mathrm{O} (\log M)\) で取得できることを利用して、元の \(\mathrm{O}(NM^2)\) の DP の計算量を遷移を変えないまま \(\mathrm{O}(NM \log M)\) に改善する解法もあります。
データ構造を使いこなすのが得意な方はこちらの解法の方が分かりやすいかもしれません。(ただし、想定解より計算量が悪くなるため定数倍に注意する必要があります。)
C++, Python による想定解は次の通りです。
- C++
#include <iostream>
#include <vector>
#include "atcoder/modint.hpp"
using namespace std;
using mint = atcoder::modint998244353;
#define rep(i, n) for (int i = 0; i < (n); i++)
int main() {
int N;
cin >> N;
vector<int> A(N), B(N);
for (auto& x : A) cin >> x;
for (auto& x : B) cin >> x;
int MAX = 3000;
vector dp(vector(N + 1, vector(MAX + 1, mint{0})));
dp[0][0] = 1;
rep(i, N + 1) {
rep(j, MAX) dp[i][j + 1] += dp[i][j];
if (i != N) {
for (int j = A[i]; j <= B[i]; j++) dp[i + 1][j] += dp[i][j];
}
}
cout << dp[N][MAX].val() << "\n";
}
- Python (処理系に PyPy3 を選ぶ必要があります)
mod = 998244353
N = int(input())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
M = 3000
dp = [1] * (M + 1)
for i in range(N):
nx = [0] * (M + 1)
for j in range(A[i], B[i] + 1):
nx[j] = dp[j]
dp = nx
for i in range(M):
dp[i + 1] += dp[i]
dp[i + 1] %= mod
print(dp[M])
posted:
last update: