Official

G - Lightweight Knapsack Editorial by en_translator


Among items of the same weight, it is optimal to choose them in descending order of value. Thus, all that matters is how many items of weight \(1\), \(2\), and \(3\) we pick. Let \(F_i\ (i=1,2,3)\) be the number of items of weight \(i\) that we pick. Also, let \(R_i := F_i \bmod \frac{6}{i}\); namely,

  • \(R_1:= F_1 \bmod 6\),
  • \(R_2:= F_2 \bmod 3\),
  • \(R_3:= F_3 \bmod 2\).

Let us fix the values \(R_1,R_2\) and \(R_3\). Then for \(i=1,2,3\),

  • The top-valued \(R_i\) items with weight \(i\) are always picked.
  • By definition of \(R_i\), the number of the other items to pick is always a multiple of \(\frac{6}{i}\).
  • Thus, we may group the remaining items \(\frac{6}{i}\) at a time, in descending order of value, and “merge” each group into a single item of weight \(6\). For example, for \(i=2\), if the remaining items have values \(5,5,4,3,2,1,1\), we may group them three at a time (discarding the leftover item) and treat them as a “weight-\(6\), value-\(14\) item” and a “weight-\(6\), value-\(6\) item.”

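By applying this merge operation for each \(i\), all items have weight \(6\), so all that is left is to choose them in descending order of value, taking as many as fit into the remaining capacity: \(\left\lfloor \frac{C - R_1 - 2R_2 - 3R_3}{6} \right\rfloor\) of them, where \(C\) is the capacity of the knapsack.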

Therefore, by trying all \(6 \times 3 \times 2 = 36\) combinations of \((R_1,R_2,R_3)\), the problem can be solved in \(O(N\log N)\) time in total.

Sample code (C++):

#include <bits/stdc++.h>

using namespace std;

using ll = long long;

// Accepts a sequence of (value, count) pairs sorted in ascending order of value;
// each pair stands for `count` items that are each worth `value`.
//
//  1. Takes the `ofs` most valuable items, i.e. those at the back of the
//     sequence, and sums their values. If fewer than `ofs` items exist, all of
//     them are taken.
//  2. Walking the remaining items from the most valuable one downwards, bundles
//     every `g` consecutive items into one merged item whose value is the sum of
//     the bundled values. A trailing bundle with fewer than `g` items is
//     discarded. `g` must be positive (it is used as a divisor).
//
// Returns {sum of the values taken in step 1, merged items as (value, count)}.
// Full bundles drawn from a single input pair are reported together as one
// (value * g, number of bundles) entry, so the result stays compact even when
// the counts are huge.
pair<ll, vector<pair<ll, ll>>> merge(vector<pair<ll, ll>> ls, int ofs, int g) {
    // Step 1: take the top `ofs` items. `need` is ll so that subtracting the
    // ll counts below involves no narrowing conversion.
    ll top = 0;
    ll need = ofs;
    while (need > 0 && !ls.empty()) {
        auto [v, k] = ls.back();
        if (k > need) {
            // Only part of this pair is needed; leave the rest for step 2.
            top += v * need;
            ls.back().second -= need;
            break;
        }
        top += v * k;
        need -= k;
        ls.pop_back();
    }

    // Step 2: bundle the rest, `g` items at a time, from the most valuable down.
    // `cnt`/`sum` describe the partially filled bundle carried between pairs.
    vector<pair<ll, ll>> res;
    ll cnt = 0, sum = 0;
    while (!ls.empty()) {
        auto [v, k] = ls.back();
        ls.pop_back();
        if (cnt) {
            if (cnt + k < g) {
                // This whole pair fits into the pending bundle without filling it.
                cnt += k;
                sum += v * k;
                continue;
            }
            // Complete the pending bundle with `g - cnt` items of value `v`.
            k -= g - cnt;
            sum += v * (g - cnt);
            res.emplace_back(sum, 1);
            cnt = sum = 0;
        }
        if (k >= g) {
            // Emit every full bundle made purely of value-`v` items at once.
            res.emplace_back(v * g, k / g);
        }
        // The remaining k % g items start a new pending bundle.
        cnt = k % g;
        sum = v * cnt;
    }
    // Any pending bundle left here has fewer than `g` items and is dropped.
    return {top, res};
}

// Reads N items (weight w in 1..3, value v, count k) and the capacity c, and
// prints the maximum total value that fits into the knapsack.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    ll c;
    cin >> n >> c;
    // ls[w] holds the (value, count) pairs of the items of weight w (1..3).
    vector<vector<pair<ll, ll>>> ls(4);
    for (int i = 0; i < n; i++) {
        ll w, v, k;
        cin >> w >> v >> k;
        ls[w].emplace_back(v, k);
    }
    for (int w = 1; w <= 3; w++) {
        sort(ls[w].begin(), ls[w].end());
    }

    // Each merge depends only on its own residue (p, q or r), so compute every
    // variant once here instead of recomputing (and re-copying the input)
    // inside the nested loops below.
    vector<pair<ll, vector<pair<ll, ll>>>> m1, m2, m3;
    for (int p = 0; p < 6; p++) m1.push_back(merge(ls[1], p, 6));
    for (int q = 0; q < 3; q++) m2.push_back(merge(ls[2], q, 3));
    for (int r = 0; r < 2; r++) m3.push_back(merge(ls[3], r, 2));

    ll ans = 0;
    for (int p = 0; p < 6; p++) {
        const auto& [t1, g1] = m1[p];
        for (int q = 0; q < 3; q++) {
            const auto& [t2, g2] = m2[q];
            for (int r = 0; r < 2; r++) {
                const auto& [t3, g3] = m3[r];
                ll rem = c - p - q * 2 - r * 3;
                if (rem < 0) continue;
                rem /= 6;  // number of weight-6 merged items that still fit
                ll now = t1 + t2 + t3;
                auto g = g1;
                g.insert(g.end(), g2.begin(), g2.end());
                g.insert(g.end(), g3.begin(), g3.end());
                sort(g.begin(), g.end());
                // All merged items weigh 6: greedily take the most valuable ones.
                while (!g.empty()) {
                    auto [v, k] = g.back();
                    if (k >= rem) {
                        now += v * rem;
                        break;
                    }
                    rem -= k;
                    now += v * k;
                    g.pop_back();
                }
                ans = max(ans, now);
            }
        }
    }
    cout << ans << '\n';
}

posted:
last update: