E - 店舗売上管理 / Store Sales Management 解説 by admin
Claude 4.5 Opus概要
配列に対して「区間の合計を求める」クエリと「1点の値を更新する」クエリを高速に処理する問題です。Fenwick Tree(Binary Indexed Tree)を使うことで効率的に解けます。
考察
素朴なアプローチの問題点
まず、単純な方法を考えてみましょう。
- 区間合計クエリ: 店舗 \(L\) から \(R\) まで順番に足し合わせる → \(O(N)\)
- 更新クエリ: 配列の値を直接書き換える → \(O(1)\)
この方法では、区間合計クエリが最大 \(Q = 2 \times 10^5\) 回、各クエリで最大 \(N = 2 \times 10^5\) 回の加算が必要になり、全体で \(O(NQ) = 4 \times 10^{10}\) 回の計算が必要です。これではTLE(時間超過)になってしまいます。
累積和を使う方法の問題点
累積和を使えば区間合計は \(O(1)\) で求められますが、1点更新のたびに累積和配列全体を更新する必要があり、更新が \(O(N)\) かかってしまいます。
解決策:Fenwick Tree
Fenwick Tree(BIT: Binary Indexed Tree) を使うと、両方の操作を \(O(\log N)\) で行えます。
アルゴリズム
Fenwick Treeとは
Fenwick Treeは、以下の2つの操作を効率的に行うためのデータ構造です:
- 1点加算: 配列の \(i\) 番目の要素に値を加える
- 先頭からの累積和: 配列の \(1\) 番目から \(i\) 番目までの合計を求める
仕組み(簡単な説明)
Fenwick Treeは、配列の要素を「区間ごとにまとめた値」として管理します。各インデックス \(i\) は、自分より前のいくつかの要素の合計を持っています。どの範囲をまとめるかは、\(i\) を2進数で表したときの最下位ビットで決まります。
例えば、\(i = 6 = (110)_2\) の場合、最下位の1のビットは \(2\) なので、\(6\) の位置には \(2\) 個分(インデックス \(5, 6\))の合計が格納されます。
本問題への適用
- 初期化: 各店舗の売上額をFenwick Treeに登録
- クエリ1(区間合計): \(\text{sum}(R) - \text{sum}(L-1)\) で計算
- クエリ2(更新): 新しい値 \(V\) と現在の値 \(S[X]\) の差分 \(V - S[X]\) をFenwick Treeに加算
例: N=5, 初期値 [10, 20, 30, 40, 50]
クエリ「1 2 4」→ 店舗2〜4の合計 = 20 + 30 + 40 = 90
クエリ「2 3 100」→ 店舗3の売上を100に更新
クエリ「1 2 4」→ 店舗2〜4の合計 = 20 + 100 + 40 = 160
計算量
時間計算量: \(O((N + Q) \log N)\)
- 初期化: \(O(N \log N)\)(\(N\) 回の add 操作)
- 各クエリ: \(O(\log N)\)(合計 \(Q\) 回)
空間計算量: \(O(N)\)
- Fenwick Tree用の配列と、現在の売上額を保持する配列
実装のポイント
差分で更新する: クエリ2では「新しい値に置き換える」のではなく、「差分を加算」します。そのため、現在の売上額 \(S[X]\) を別途保持しておき、\(V - S[X]\) を計算してからFenwick Treeに加算します。
1-indexed: Fenwick Treeは1-indexedで実装するのが一般的です。店舗番号が1から始まるので、そのまま対応できます。
高速な入出力: \(N, Q\) が最大 \(2 \times 10^5\) と大きいため、
sys.stdin.buffer.read()で一括入力し、結果も'\n'.join()でまとめて出力することで高速化しています。ソースコード
import sys
class FenwickTree:
def __init__(self, n):
self.n = n
self.tree = [0] * (n + 1)
def add(self, i, delta):
while i <= self.n:
self.tree[i] += delta
i += i & (-i)
def sum(self, i):
s = 0
while i > 0:
s += self.tree[i]
i -= i & (-i)
return s
def range_sum(self, l, r):
return self.sum(r) - self.sum(l - 1)
def main():
input_data = sys.stdin.buffer.read().split()
idx = 0
N = int(input_data[idx]); idx += 1
Q = int(input_data[idx]); idx += 1
S = [0] * (N + 1)
for i in range(1, N + 1):
S[i] = int(input_data[idx]); idx += 1
ft = FenwickTree(N)
for i in range(1, N + 1):
ft.add(i, S[i])
results = []
for _ in range(Q):
query_type = int(input_data[idx]); idx += 1
if query_type == 1:
L = int(input_data[idx]); idx += 1
R = int(input_data[idx]); idx += 1
results.append(ft.range_sum(L, R))
else:
X = int(input_data[idx]); idx += 1
V = int(input_data[idx]); idx += 1
delta = V - S[X]
S[X] = V
ft.add(X, delta)
print('\n'.join(map(str, results)))
if __name__ == "__main__":
main()
この解説は claude4.5opus によって生成されました。
投稿日時:
最終更新: