H - Snuketoon Editorial by Nyaan
まずはナイーブな DP を考えてみましょう。
- \(dp_{i, x} :=\) \(T_i\) 秒後に \(x\) にいるときのダメージの最小値
とおくと、便宜上 \(T_0 = 0\) として次の遷移が求まります。
\[ dp_{i, x} = \begin{cases} 0 & i = 0 \wedge x = 0 \\ \infty & i = 0 \wedge x \neq 0 \\ \displaystyle \min_{ |y - x| \leq T_i - T_{i-1} } dp_{i-1,y} + \max(0, X_i - x)& i \gt 0 \wedge D_i = 0 \\ \displaystyle \min_{ |y - x| \leq T_i - T_{i-1} } dp_{i-1,y} + \max(0, x - X_i) & i \gt 0 \wedge D_i = 1 \end{cases} \]
この DP を愚直に計算すれば \(\mathrm{O}(\max(T)^2)\) で求めることができますが当然 TLE してしまいます。
ここで、 DP テーブルの \(i\) の値を固定して、 \(-T_i \leq x \leq T_i\) の範囲において \(xy\) 座標上に \((x, d_{i,x})\) をプロットして線で結んでみましょう。すると、得られる線は \(\mathrm{O}(i)\) 本の折れ線からなる凸関数になると思います。実際に
- \(f_i(x) :=\) \(-T_i \leq x \leq T_i\) の範囲において \(f_i(x) = dp_{i,x}\) を満たす関数
とおいて上記の DP を \(f(x)\) を使用した形に書き換えると次のようになり、帰納的に \(f_i(x)\) が \(-T_i \leq x \leq T_i\) の範囲で凸関数になるのが確認できると思います。 ( ここで \((X_i-x)_{+}\) は \(\max((X_i-x),0)\) の意味です。)
\[ f_i(x) = \begin{cases} 0 & i = 0 \\ \displaystyle (X_i - x)_{+} + \min_{ |y - x| \leq T_i - T_{i-1} } f_{i-1}(y) & i \gt 0 \wedge D_i = 0 \\ \displaystyle (x - X_i)_{+} + \min_{ |y - x| \leq T_i - T_{i-1} } f_{i-1}(y) & i \gt 0 \wedge D_i = 0 \\ \end{cases} \]
このような DP は変化点を 2 個の優先度付きキュー で管理する Slope Trick と呼ばれる手法で高速に凸関数をシミュレートすることができます。
Slope Trick の解説記事は、日本語文献では maspy さんの記事 が非常に分かりやすく、AtCoder での出題例も多く含まれておりおすすめです。英語話者の方は Codeforces の記事 をご参照ください。
上記の通り Slope Trick に関する解説はすでに非常に素晴らしい文献が存在するのでここでは割愛します。以上よりこの問題を \(\mathrm{O}(N \log N)\) で解くことができました。
Python による実装は以下の通りです。 ( 以下の提出で AC するためには処理系に PyPy3 を選ぶ必要があります。)
import heapq
import sys
L, R, addL, addR, miny = list(), list(), 0, 0, 0
def push_left(x):
heapq.heappush(L, -x + addL)
def push_right(x):
heapq.heappush(R, x - addR)
def top_left():
return -L[0] + addL
def top_right():
return R[0] + addR
def pop_left():
return -heapq.heappop(L) + addL
def pop_right():
return heapq.heappop(R) + addR
def add_xma(a):
global miny
if len(L) != 0:
miny += max(0, top_left() - a)
push_left(a)
push_right(pop_left())
def add_amx(a):
global miny
if len(R) != 0:
miny += max(0, a - top_right())
push_right(a)
push_left(pop_right())
it = map(int, sys.stdin.buffer.read().split())
N = next(it)
L.extend([0] * (N + 10))
R.extend([0] * (N + 10))
t = 0
for T, D, X in zip(it, it, it):
addL -= T - t
addR += T - t
if D == 0:
add_amx(X)
else:
add_xma(X)
t = T
print(miny)
posted:
last update: