E - アルバイトのシフト管理 / Part-Time Job Shift Management 解説
by
kyopro_friends
この問題はDPで解くことができます。
この手の問題は「最後に○○したとき」を状態に持つDPを考えるのが定石です。今回は
\(\mathrm{dp}_0[i]=\) \(i\) 日目が連休最終日ときの、 \(i\) 日目までの報酬の最大値
\(\mathrm{dp}_1[i]=\) \(i\) 日目が連勤最終日ときの、 \(i\) 日目までの報酬の最大値
と定めます。もらう DP を考えます。
\(S_i=\sum_{j=1}^{i}A_i\) とおきます。
\(\mathrm{dp}_0[i]\) を求めるには直前の連勤の最終日がいつであるか、\(\mathrm{dp}_1[i]\) を求めるには直前の連休の最終日がいつであるかを考えることで、
\[\begin{aligned}\mathrm{dp}_0[i] &=\max_{i-M+1\leq j \leq i-1}\mathrm{dp}_1[j]\\ \mathrm{dp}_1[i]&=\max_{i-K+1\leq j \leq i-1}\left(\mathrm{dp}_0[j]+S_i-S_j\right)\end{aligned}\]
という漸化式を得ることができます。
\(\mathrm{dp}_1[i]\) の右辺に登場する max は \(S_j\) の影響があるためこのままでは高速に計算するのは困難です。そこで \(S_j\) も \(\mathrm{dp}_0[j]\) に押し付けてしまうことにします。
\(\mathrm{dp}_0'[i]=\mathrm{dp}_0[i]-S_i\) とおくと
\[\mathrm{dp}'_0[i]=\left(\max_{i-M+1\leq j \leq i-1}\mathrm{dp}_1[j]\right)-S_i\]
\[\mathrm{dp}_1[i]=\left(\max_{i-K+1\leq j \leq i-1}\mathrm{dp}'_0[j]\right)+S_i\]
となり、スライド最大値やセグメントツリーなどを用いることで、 \(O(N)\) や \(O(N\log N)\) で DP テーブルを埋めることができます。求める答えは \(\max(\mathrm{dp}_0[N], \mathrm{dp}_1[N])\) です。
実装例 (C++)
#include<bits/stdc++.h>
#include<atcoder/segtree>
using namespace std;
long long op(long long x, long long y){return max(x, y);}
long long e(){return (long long)-1e18;}
int main(){
int n, k, m;
cin >> n >> k >> m;
vector<int>a(n);
for(int i=0; i<n; i++) cin >> a[i];
vector<long long>s(n+1);
for(int i=0; i<n; i++){
s[i+1] = s[i] + a[i];
}
atcoder::segtree<long long, op, e> seg0(n+1), seg1(n+1);
seg0.set(0, 0);
seg1.set(0, 0);
for(int i=1; i<=n; i++){
seg0.set(i, seg1.prod(max(0, i-m+1), i) - s[i]);
seg1.set(i, seg0.prod(max(0, i-k+1), i) + s[i]);
}
cout << max(seg0.get(n) + s[n], seg1.get(n)) << endl;
}
実装例 (Python)
from atcoder.segtree import SegTree
N, K, M = map(int, input().split())
A = list(map(int, input().split()))
S = [0]
for a in A:
S.append(S[-1] + a)
seg0 = SegTree(max, -10**18, N+1)
seg1 = SegTree(max, -10**18, N+1)
seg0.set(0, 0)
seg1.set(0, 0)
for i in range(1, N+1):
seg0.set(i, seg1.prod(max(0, i-M+1), i) - S[i])
seg1.set(i, seg0.prod(max(0, i-K+1), i) + S[i])
print(max(seg0.get(N) + S[N], seg1.get(N)))
投稿日時:
最終更新:
