Official

E - 宝箱の選択 / Choosing Treasure Chests Editorial by admin

Gemini 3.1 Pro (Thinking)

概要

各宝箱について、その宝箱を必ず選ぶという条件のもとでのナップサック問題の最大価値が、全体の最大価値と一致するかどうかを判定する問題です。

考察

まず、全体の最大価値を求めるだけであれば、有名な「ナップサックDP」を用いて \(O(NM)\) で解くことができます。 しかし、今回の問題では「各宝箱 \(i\) について、それを必ず選んだときの最大価値」を求める必要があります。素朴に考えると、各宝箱を必ず選ぶ(あるいは宝箱 \(i\) を除外した)ナップサックDPを \(N\) 回やり直すことになり、時間計算量が \(O(N^2 M)\) となってしまいます。制約が \(N \leq 1,000, M \leq 3,000\) なので、これでは実行時間制限(TLE)に引っかかってしまいます。

これを解決するために、「前後からのDP(左右からのDP)」というテクニックを使用します。 宝箱 \(i\) を選んだとき、残りの宝箱は「\(1\) 番目から \(i-1\) 番目」と「\(i+1\) 番目から \(N\) 番目」の2つのグループに分かれます。 あらかじめ、前から順番に品物を追加していったDPテーブルと、後ろから順番に追加していったDPテーブルの2つを用意しておけば、宝箱 \(i\) を選ぶときの最大価値を高速に求めることができます。

アルゴリズム

  1. 前からのDP (dp_pref) dp_pref[i][w] を「\(1\) 番目から \(i\) 番目までの宝箱の中から、重さの合計が \(w\) 以下になるように選んだときの価値の最大値」とします。通常のナップサックDPと同様に計算します。
  2. 後ろからのDP (dp_suff) dp_suff[i][w] を「\(i\) 番目から \(N\) 番目までの宝箱の中から、重さの合計が \(w\) 以下になるように選んだときの価値の最大値」とします。後ろから前へ向かって同様に計算します。
  3. 全体の最大価値 すべての宝箱を考慮したときの最大価値 MAX_Vdp_pref[N][M] になります。
  4. 各宝箱の判定 宝箱 \(i\) を必ず選ぶとします。このとき、重さ \(A_i\) を消費し、価値 \(B_i\) を得ます。 残りの容量は \(rem\_w = M - A_i\) となります。 この残り容量を、前側の宝箱と後ろ側の宝箱に分配します。前側に重さ \(w\) を割り当てた場合、後ろ側には重さ \(rem\_w - w\) を割り当てることができます。 したがって、宝箱 \(i\) を含む場合の最大価値は、以下の式で求められます。 $\( \max_{0 \leq w \leq rem\_w} (dp\_pref[i-1][w] + dp\_suff[i+1][rem\_w - w]) + B_i \)$ この値が MAX_V と等しければ Yes、小さければ No を出力します。

計算量

  • 時間計算量: \(O(NM)\)
    • 前からのDPと後ろからのDPの構築にそれぞれ \(O(NM)\) かかります。
    • 各宝箱の判定において、容量の分配を試すのに \(O(M)\) かかるため、全体で \(O(NM)\) となります。
  • 空間計算量: \(O(NM)\)
    • 状態数が \((N+2) \times (M+1)\) のDPテーブルを2つ保持するため、\(O(NM)\) のメモリを使用します。

実装のポイント

  • Pythonなどのスクリプト言語では、多重ループの回数が多いと実行時間が厳しくなることがあります。各宝箱の判定における \(\max\) の計算では、operator.addmap を用いて、リストのスライス同士(pref[:rem_w+1]suff[rem_w::-1])の足し算をC言語レベルで一括処理することで高速化を図っています。

  • suff[rem_w::-1] は、配列をインデックス rem_w から \(0\) まで逆順に取得するスライス操作です。これにより、\(w\)\(rem\_w - w\) のペアの和を、無駄なループを書くことなく簡潔かつ効率よく求めることができます。

    ソースコード

import sys
import operator

def solve():
    input_data = sys.stdin.read().split()
    if not input_data:
        return
    
    N = int(input_data[0])
    M = int(input_data[1])
    
    A = [0] * (N + 1)
    B = [0] * (N + 1)
    
    idx = 2
    for i in range(1, N + 1):
        A[i] = int(input_data[idx])
        B[i] = int(input_data[idx+1])
        idx += 2
        
    dp_pref = [[0] * (M + 1) for _ in range(N + 2)]
    for i in range(1, N + 1):
        w_i = A[i]
        v_i = B[i]
        prev = dp_pref[i-1]
        curr = prev[:]
        for w in range(w_i, M + 1):
            val = prev[w-w_i] + v_i
            if val > curr[w]:
                curr[w] = val
        dp_pref[i] = curr
        
    dp_suff = [[0] * (M + 1) for _ in range(N + 2)]
    for i in range(N, 0, -1):
        w_i = A[i]
        v_i = B[i]
        nxt = dp_suff[i+1]
        curr = nxt[:]
        for w in range(w_i, M + 1):
            val = nxt[w-w_i] + v_i
            if val > curr[w]:
                curr[w] = val
        dp_suff[i] = curr
        
    MAX_V = dp_pref[N][M]
    
    ans = []
    for i in range(1, N + 1):
        w_i = A[i]
        v_i = B[i]
        rem_w = M - w_i
        if rem_w < 0:
            ans.append("No")
            continue
            
        pref = dp_pref[i-1]
        suff = dp_suff[i+1]
        
        max_val = max(map(operator.add, pref[:rem_w+1], suff[rem_w::-1])) + v_i
        
        if max_val == MAX_V:
            ans.append("Yes")
        else:
            ans.append("No")
            
    sys.stdout.write('\n'.join(ans) + '\n')

if __name__ == '__main__':
    solve()

この解説は gemini-3.1-pro-thinking によって生成されました。

posted:
last update: