F - Manhattan Cafe Editorial by Nyaan
第 \(1\) 成分から順に決めていく DP を考えてみましょう。
このとき、DP を動作させるうえで必要な情報は (\(p\) と \(r\) の距離), (\(q\) と \(r\) の距離) の 2 つの情報なので、その 2 つを添え字として持って次のように状態を定義します。
- \(dp_{t, i, j} :=\) (第 \(t\) 軸までの \(r\) の値が確定していて, \(\sum_{n=1}^t |p_n - r_n| = i, \sum_{n=1}^t |q_n - r_n| = j\) であるような場合の数)
このとき、問題の答えは下の式で表されます。
\[\sum_{i=0}^D \sum_{j=0}^D dp_{N, i, j}\]
また、DP の初期状態や遷移は以下の C++ 風のコードで表せます。
#define rep(i, n) for (int i = 0; i < (int)(n); i++)
constexpr int MAX = 1001;
atcoder::modint998244353 dp[MAX][MAX], nx[MAX][MAX];
dp[0][0] = 1;
rep(t, N) {
rep(i, D + 1) rep(j, D + 1) nx[i][j] = 0;
ll pt = p[t], qt = q[t];
for (int rt = -MAX * 2; rt <= MAX * 2; rt++) {
int di = abs(pt - rt);
int dj = abs(qt - rt);
if (min(di, dj) > D) continue;
for (int i = 0, ie = D - di; i <= ie; i++) {
for (int j = 0, je = D - dj; j <= je; j++) {
nx[i + di][j + dj] += dp[i][j];
}
}
}
swap(dp, nx);
}
以上の考察でこの問題を解けた…と書きたいところですが、時間計算量の問題があります。上の DP は計算量 \(\mathrm{O}(ND^3)\) なので制約下では TLE してしまいます。
そこで、 DP の遷移がある特長を持つことに注目してみましょう。上のコード中の \((di, dj)\) の組は \(s = \vert p_t - q_t \vert\) として次の 3 種類に分類できるのが確認できます。
- \((s,0), (s-1, 1), (s-2, 2), \dots, (1, s-1), (0, s)\)
- \((s+1, 1), (s+2, 2), (s+3, 3), \dots\)
- \((1, s+1), (2, s+2), (3, s+3), \dots\)
それぞれの集合の要素をグリッドに図示するといずれも斜めに一直線に並びます。よって、あらかじめ DP テーブルの値斜め方向に累積和を取っておくことで、\((i, j)\) が定まった時の遷移を \(\mathrm{O}(D)\) から \(\mathrm{O}(1)\) に落とすことができて、全体の計算量が \(\mathrm{O}(ND^3)\) から \(\mathrm{O}(N D^2)\) に改善されます。(詳しくは実装例を参考にしてください。)よってこの問題を十分高速に解くことができました。
- 実装例(PyPy)
mod = 998244353
N, D = map(int, input().split())
p = list(map(int, input().split()))
q = list(map(int, input().split()))
dp = [[0 for _ in range(D + 1)] for _ in range(D + 1)]
dp[0][0] = 1
for n in range(N):
pn, qn = p[n], q[n]
s = abs(pn - qn)
nxt = [[0 for _ in range(D + 1)] for _ in range(D + 1)]
dp2 = [[0 for _ in range(D + 1)] for _ in range(D + 1)]
for i in range(D + 1):
for j in range(D + 1):
dp2[i][j] = dp[i][j]
if i != 0 and j != D:
dp2[i][j] += dp2[i - 1][j + 1]
dp2[i][j] %= mod
for i in range(D + 1):
for j in range(D + 1):
si = i
sj = j - s
if sj < 0:
si += sj
sj = 0
if 0 <= si <= D and 0 <= sj <= D:
nxt[i][j] += dp2[si][sj]
nxt[i][j] %= mod
ti = i - (s + 1)
tj = j + 1
if 0 <= ti <= D and 0 <= tj <= D:
nxt[i][j] -= dp2[ti][tj]
nxt[i][j] %= mod
dp3 = [[0 for _ in range(D + 1)]for _ in range(D + 1)]
for i in range(D + 1):
for j in range(D + 1):
dp3[i][j] = dp[i][j]
if i != 0 and j != 0:
dp3[i][j] += dp3[i - 1][j - 1]
dp3[i][j] %= mod
if i + 1 <= D and j + s + 1 <= D:
nxt[i + 1][j + s + 1] += dp3[i][j]
nxt[i + 1][j + s + 1] %= mod
if i + s + 1 <= D and j + 1 <= D:
nxt[i + s + 1][j + 1] += dp3[i][j]
nxt[i + s + 1][j + 1] %= mod
dp = nxt
print(sum((sum(v) for v in dp)) % mod)
なお、この実装例は可読性のために定数倍が悪い実装になっています。TL に十分余裕を持って実装したい場合は、
dp = [[0 for _ in range(D + 1)] for _ in range(D + 1)]
の部分を
dp = [0] * (D + 1) * (D + 1)
# dp[i * (D + 1) + j] で (i, j) 成分を参照
というように書き換えれば \(2\) 倍から \(3\) 倍程度高速に動作すると考えられるので、そのように実装することをお勧めします。
- TL が低速な言語に厳しめに設定されているのは、C++ による愚直なアルゴリズム (時間計算量 \(\mathrm{O}(\frac{N D^3}{s})\) ) がラフな実装でも 10 sec 前後で動作して、この解法を TLE させることを目的としたものです。(ここでの \(s\) は
pragma
構文などを用いた自動ベクトル化による並列計算で同時に処理できるデータの個数で、 \(s=8\) または \(16\))- 鋭い方は「ナイーブな DP を落とすには \(D\) を大きくすればいいじゃないか」と考えるかもしれませんが( 2 つの解法の TL の比は \(D/s\) に依存するため)、TLE してほしい解法に \(\mathrm{O}(N^2 D^2)\) 解法も存在するため、これ以上 \(D\) を大きくすると相対的に \(N\) が小さくなって \(\mathrm{O}(N^2 D^2)\) 解法が通りやすくなってしまう問題があります。
posted:
last update: