Official

G - Range Knapsack Query Editorial by yuto1115

解説

列に対する分割統治を用います。

\(f(l,r)\) を、区間 \([l,r]\) に包含されるクエリに対する答えを全て求める関数とします。特に、\(f(1,N)\) を呼ぶと全てのクエリに対して答えを求めることができます。

\(m=\frac{l+r}{2}\) とおきます。区間 \([l, m-1], [m, r]\) のいずれかに包含されているクエリについては \(f(l, m-1), f(m, r)\) を呼ぶことで再帰的に処理すればよいため、ここで求めたいのは \(l\leq L_i < m \leq R_i \leq r\) なるクエリ全てに対する答えです。

以下のように DP を定義します。

  • \(\mathrm{dp}_l(i, j)\ (l\leq i < m)=\)(アイテム \(i\dots m-1\) から重みの総和が \(j\) 以下となるようにいくつか選ぶとき、価値の総和の最大値)
  • \(\mathrm{dp}_r(i, j)\ (m\leq i\leq r)=\)(アイテム \(m\dots i\) から重みの総和が \(j\) 以下となるようにいくつか選ぶとき、価値の総和の最大値)

これらの値は、\(\mathrm{dp}_l\) については \(i\) の降順に、 \(\mathrm{dp}_r\) については \(i\) の昇順に、合計 \(O((r-l)K)\)\(K=\max C_i\))の計算量で求めることができます。

これを求めておけば、\(l\leq L_i < m < R_i \leq r\) なるクエリ \(i\) に対する答えは \(\max_{0\leq j\leq C_i} (\mathrm{dp}_l(L_i, j)+\mathrm{dp}_r(R_i,C_i-j))\) となるため \(O(K)\) で求まります。

全体の計算量は \(O(K(Q + N\log N))\) です。(下記の実装例では \(O(K(Q + N\log N)+Q\log N)\) となっていますが、\(Q\log N\) の項を取り除くことも可能です。)

実装例 (C++) :

#include <bits/stdc++.h>

using namespace std;

using ll = long long;

const int K = 510;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    vector<int> w(n), v(n);
    for (int i = 0; i < n; i++) {
        cin >> w[i] >> v[i];
    }
    int q;
    cin >> q;
    vector<tuple<int, int, int>> query;
    for (int i = 0; i < q; i++) {
        int l, r, c;
        cin >> l >> r >> c;
        --l;
        query.emplace_back(l, r, c);
    }
    vector<ll> ans(q);
    auto upd_dp = [&](const vector<ll> &dp_pre, vector<ll> &dp_nx, int i) {
        for (int j = 0; j < K; j++) {
            dp_nx[j] = dp_pre[j];
            if (j >= w[i]) {
                dp_nx[j] = max(dp_nx[j], dp_pre[j - w[i]] + v[i]);
            }
        }
    };
    vector dp(n + 1, vector<ll>(K));
    auto f = [&](auto &f, int l, int r, const vector<int> &qid) -> void {
        if (l + 1 == r) {
            for (int i: qid) {
                auto [nl, nr, nc] = query[i];
                assert(nl == l and nr == r);
                ans[i] = (nc >= w[l] ? v[l] : 0);
            }
            return;
        }
        int m = (l + r) / 2;
        fill(dp[m].begin(), dp[m].end(), 0);
        for (int i = m - 1; i >= l; i--) upd_dp(dp[i + 1], dp[i], i);
        for (int i = m + 1; i <= r; i++) upd_dp(dp[i - 1], dp[i], i - 1);
        vector<int> qid_l, qid_r;
        for (int i: qid) {
            auto [nl, nr, nc] = query[i];
            if (nr <= m) {
                qid_l.push_back(i);
            } else if (nl >= m) {
                qid_r.push_back(i);
            } else {
                for (int j = 0; j <= nc; j++) {
                    ans[i] = max(ans[i], dp[nl][j] + dp[nr][nc - j]);
                }
            }
        }
        f(f, l, m, qid_l);
        f(f, m, r, qid_r);
    };
    vector<int> qid(q);
    iota(qid.begin(), qid.end(), 0);
    f(f, 0, n, qid);
    for (int i = 0; i < q; i++) {
        cout << ans[i] << '\n';
    }
}

posted:
last update: