G - Range Knapsack Query Editorial
by
yuto1115
解説
列に対する分割統治を用います。
\(f(l,r)\) を、区間 \([l,r]\) に包含されるクエリに対する答えを全て求める関数とします。特に、\(f(1,N)\) を呼ぶと全てのクエリに対して答えを求めることができます。
\(m=\frac{l+r}{2}\) とおきます。区間 \([l, m-1], [m, r]\) のいずれかに包含されているクエリについては \(f(l, m-1), f(m, r)\) を呼ぶことで再帰的に処理すればよいため、ここで求めたいのは \(l\leq L_i < m \leq R_i \leq r\) なるクエリ全てに対する答えです。
以下のように DP を定義します。
- \(\mathrm{dp}_l(i, j)\ (l\leq i < m)=\)(アイテム \(i\dots m-1\) から重みの総和が \(j\) 以下となるようにいくつか選ぶとき、価値の総和の最大値)
- \(\mathrm{dp}_r(i, j)\ (m\leq i\leq r)=\)(アイテム \(m\dots i\) から重みの総和が \(j\) 以下となるようにいくつか選ぶとき、価値の総和の最大値)
これらの値は、\(\mathrm{dp}_l\) については \(i\) の降順に、 \(\mathrm{dp}_r\) については \(i\) の昇順に、合計 \(O((r-l)K)\)(\(K=\max C_i\))の計算量で求めることができます。
これを求めておけば、\(l\leq L_i < m < R_i \leq r\) なるクエリ \(i\) に対する答えは \(\max_{0\leq j\leq C_i} (\mathrm{dp}_l(L_i, j)+\mathrm{dp}_r(R_i,C_i-j))\) となるため \(O(K)\) で求まります。
全体の計算量は \(O(K(Q + N\log N))\) です。(下記の実装例では \(O(K(Q + N\log N)+Q\log N)\) となっていますが、\(Q\log N\) の項を取り除くことも可能です。)
実装例 (C++) :
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int K = 510;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
vector<int> w(n), v(n);
for (int i = 0; i < n; i++) {
cin >> w[i] >> v[i];
}
int q;
cin >> q;
vector<tuple<int, int, int>> query;
for (int i = 0; i < q; i++) {
int l, r, c;
cin >> l >> r >> c;
--l;
query.emplace_back(l, r, c);
}
vector<ll> ans(q);
auto upd_dp = [&](const vector<ll> &dp_pre, vector<ll> &dp_nx, int i) {
for (int j = 0; j < K; j++) {
dp_nx[j] = dp_pre[j];
if (j >= w[i]) {
dp_nx[j] = max(dp_nx[j], dp_pre[j - w[i]] + v[i]);
}
}
};
vector dp(n + 1, vector<ll>(K));
auto f = [&](auto &f, int l, int r, const vector<int> &qid) -> void {
if (l + 1 == r) {
for (int i: qid) {
auto [nl, nr, nc] = query[i];
assert(nl == l and nr == r);
ans[i] = (nc >= w[l] ? v[l] : 0);
}
return;
}
int m = (l + r) / 2;
fill(dp[m].begin(), dp[m].end(), 0);
for (int i = m - 1; i >= l; i--) upd_dp(dp[i + 1], dp[i], i);
for (int i = m + 1; i <= r; i++) upd_dp(dp[i - 1], dp[i], i - 1);
vector<int> qid_l, qid_r;
for (int i: qid) {
auto [nl, nr, nc] = query[i];
if (nr <= m) {
qid_l.push_back(i);
} else if (nl >= m) {
qid_r.push_back(i);
} else {
for (int j = 0; j <= nc; j++) {
ans[i] = max(ans[i], dp[nl][j] + dp[nr][nc - j]);
}
}
}
f(f, l, m, qid_l);
f(f, m, r, qid_r);
};
vector<int> qid(q);
iota(qid.begin(), qid.end(), 0);
f(f, 0, n, qid);
for (int i = 0; i < q; i++) {
cout << ans[i] << '\n';
}
}
posted:
last update: