Official

C - 貯金箱の管理 / Piggy Bank Management Editorial by admin

Claude 4.6 Opus (Thinking)

概要

\(N\) 個の貯金箱に対して「区間への一律加算」と「1点の値の取得」を効率的に処理する問題です。BIT(Binary Indexed Tree)を用いた差分配列のテクニックで高速に解くことができます。

考察

素朴なアプローチの問題点

種類1の操作(区間加算)を素朴に実装すると、貯金箱 \(L\) から \(R\) まで1つずつ値を加算する必要があり、1回の操作に \(O(N)\) かかります。\(Q\) 回の操作があるため、最悪 \(O(NQ)\) となり、\(N, Q\) が最大 \(2 \times 10^5\) のとき \(4 \times 10^{10}\) 回程度の計算が必要になり、TLE(時間超過)になります。

重要な気づき:差分配列

区間 \([L, R]\) に一律 \(X\) を加算し、ある1点 \(P\) の値を求めたいという状況では、差分配列のテクニックが有効です。

差分配列 \(D\) を用意し、区間 \([L, R]\)\(X\) を加算する操作を以下のように表現します:

  • \(D[L]\)\(+X\)
  • \(D[R+1]\)\(-X\)

すると、位置 \(P\) に加算された合計値は \(D[1] + D[2] + \cdots + D[P]\)(= \(D\)\(P\) までの累積和)で求められます。

具体例: \(N = 5\) で区間 \([2, 4]\)\(+3\) を加算する場合

インデックス 1 2 3 4 5
\(D\) の変化 0 +3 0 0 -3
累積和 0 3 3 3 0

確かに位置 \(2, 3, 4\) だけに \(3\) が加算されています。

さらなる高速化:BIT の利用

差分配列の累積和を毎回 \(O(N)\) で計算していてはまだ遅いです。そこで BIT(Binary Indexed Tree / フェニック木) を使うことで、以下の操作をそれぞれ \(O(\log N)\) で行えます:

  • 1点更新: \(D[i]\) に値を加算する
  • 累積和クエリ: \(D[1] + D[2] + \cdots + D[P]\) を求める

アルゴリズム

  1. 初期金額 \(A\) を配列として保持する。
  2. BIT を差分配列として初期化(全て \(0\))。
  3. 各クエリを処理する:
    • 種類1(区間 \([L, R]\)\(X\) 加算):BIT 上で \(L\)\(+X\)\(R+1\)\(-X\) を加算する。
    • 種類2(位置 \(P\) の値を取得):\(A[P] +\) BIT の \(P\) までの累積和を出力する。

計算量

  • 時間計算量: \(O(N + Q \log N)\)
    • 初期配列の読み込みに \(O(N)\)、各クエリにつき BIT 操作が \(O(\log N)\)
  • 空間計算量: \(O(N)\)
    • 初期配列と BIT 配列の分

実装のポイント

  • BIT のインデックスは 1-indexed で扱うのが一般的です。初期配列 \(A\) は 0-indexed なので、クエリ時に A[P-1] とインデックスをずらす点に注意してください。

  • 差分配列で \(R+1\)\(-X\) を加算する際、\(R+1 > N\) の場合は配列外なので加算不要です(コード中の if R + 1 <= N の条件)。

  • 出力を1つずつ print するのではなく、リストに溜めて最後に sys.stdout.write で一括出力すると、I/O がボトルネックになりにくくなります。

    ソースコード

import sys
input = sys.stdin.readline

def main():
    N, Q = map(int, input().split())
    A = list(map(int, input().split()))
    
    # BIT (Binary Indexed Tree) for range update, point query
    # Using two BITs to support range add and point query
    # For range [L, R] add X:
    #   bit.add(L, X), bit.add(R+1, -X)
    # Point query P: prefix sum up to P
    
    bit = [0] * (N + 2)
    
    def update(i, val):
        while i <= N:
            bit[i] += val
            i += i & (-i)
    
    def query(i):
        s = 0
        while i > 0:
            s += bit[i]
            i -= i & (-i)
        return s
    
    out = []
    for _ in range(Q):
        line = input().split()
        if line[0] == '1':
            L = int(line[1])
            R = int(line[2])
            X = int(line[3])
            update(L, X)
            if R + 1 <= N:
                update(R + 1, -X)
        else:
            P = int(line[1])
            out.append(str(A[P - 1] + query(P)))
    
    sys.stdout.write('\n'.join(out) + '\n')

main()

この解説は claude4.6opus-thinking によって生成されました。

posted:
last update: