G - Range Knapsack Query Editorial by drogskol

Disjoint Sparse Table を使った解法

Disjoint Sparse Table

公式解説の解き方において、最初に dp テーブルを全て前計算しておくことでこの問題はオンラインで解く事ができます。
これはデータ構造として名前がついていて、Disjoint Sparse Table と呼ばれています。

Disjoint Sparse Table については以下の記事がわかりやすいです。

実装例

#include <bits/stdc++.h>

int main()
{
    int n;
    std::cin >> n;

    std::vector<std::pair<int, int>> a(n);
    for (auto &[w, v] : a)
        std::cin >> w >> v;

    using A500 = std::array<uint64_t, 501>;

    std::vector<A500> init_table(n);
    for (int i = 0; i < n; i++)
    {
        init_table[i].fill(0);
        const auto &[w, v] = a[i];
        for (int j = w; j <= 500; j++)
            init_table[i][j] = v;
    }

    // Disjoint Sparse Table の構築
    int LOG = std::bit_width(unsigned(n - 1));
    std::vector<std::vector<A500>> disjoint_sparse_table(LOG, init_table);

    auto merge = [&](A500 table, const std::pair<int, int> &x) -> A500
    {
        const auto &[w, v] = x;
        for (int i = 500 - w; i >= 0; i--)
            table[i + w] = std::max(table[i + w], table[i] + v);
        return table;
    };

    for (int k = 1; k <= LOG; ++k)
    {
        auto &table = disjoint_sparse_table[k];
        int len = 1 << k;
        for (int L = 0; L + len < n; L += len * 2)
        {
            int M = L + len;
            int R = std::min(M + len, n);

            // 左ブロック(中央→左)
            for (int i = M - 2; i >= L; --i)
                table[i] = merge(table[i + 1], a[i]);

            // 右ブロック(中央→右)
            for (int i = M + 1; i < R; ++i)
                table[i] = merge(table[i - 1], a[i]);
        }
    }

    int q;
    std::cin >> q;
    while (q--)
    {
        int l, r, c;
        std::cin >> l >> r >> c;
        l--;

        if (l + 1 == r)
        {
            std::cout << (a[l].first <= c ? a[l].second : 0) << "\n";
            continue;
        }

        const auto &table = disjoint_sparse_table[std::bit_width(unsigned(l ^ (r - 1))) - 1];
        const auto &left = table[l];
        const auto &right = table[r - 1];

        uint64_t ans = 0;
        for (int i = 0; i <= c; i++)
            ans = std::max(ans, left[i] + right[c - i]);
        std::cout << ans << "\n";
    }
}

計算量解析

公式解説と同様に \( K=\max C_j\) とします。

時間計算量

  • 前計算: \(O(KN \log N)\)
  • クエリ処理: \(O(QK)\)

処理も軽いため、これは制約内で十分間に合います。

空間計算量

テーブルは長さ \(K\) の 64 bit 整数配列を \(N\log N\) 個ほど管理するため、ざっくり

\(K \times N\log N \times 64 \simeq 50\times 20000 \times 15 \times 64 \simeq 9 \times 10^9\) bit

となります。

この問題の制約は約 \(8 \times 10^9\) bit なため、制約違反です。
実際、上で貼ったコードは MLE します。
https://atcoder.jp/contests/abc426/submissions/69966867 (MLE)

そこでメモリ使用量の削減を考えます。
テーブルの値の最大値は \(500 \times 10^9\) なので、40bit あれば十分です。
したがって、テーブルの値の型を 64bit から 40bit に削減すると、

\(9 \times 10^9 \times (40 / 64)\simeq 6\times 10^9\) bit

となり、メモリ制限内に収まります。

今回の40bit 整数は加算と大小比較のみ行えれば良いので、例えば以下のようにして実装出来ます。(アラインメントに注意してください)

#pragma pack(push, 1)
class uint40_t
{
    uint32_t low; // 下位32ビット
    uint8_t high; // 上位8ビット
public:
    uint40_t(uint64_t val = 0)
        : low(static_cast<uint32_t>(val)),
          high(static_cast<uint8_t>(val >> 32)) {}

    uint64_t to_uint64() const
    {
        return (static_cast<uint64_t>(high) << 32) | low;
    }

    friend uint40_t operator+(const uint40_t &a, const uint40_t &b)
    {
        return uint40_t(a.to_uint64() + b.to_uint64());
    }

    auto operator<=>(const uint40_t &other) const
    {
        return to_uint64() <=> other.to_uint64();
    }
};
#pragma pack(pop)

これを用いて最初の実装の uint64_t を uint40_t に置き換えると MLE が解消し AC することが出来ます。
https://atcoder.jp/contests/abc426/submissions/69966931 (1131 ms)

posted:
last update: