Official

C - 照明スイッチの操作 / Light Switch Operation Editorial by admin

Gemini 3.1 Pro (Thinking)

概要

\(N\) 個の照明に対して、指定された区間にある照明のスイッチを切り替える操作を \(M\) 回行った後、最終的に点灯している照明の個数を求める問題です。

考察

照明の状態は「スイッチが合計で何回切り替えられたか」によって決まります。最初すべて消灯しているため、切り替え回数が 奇数回なら点灯偶数回なら消灯 となります。

素朴なアプローチとして、各操作ごとに \(L_j\) から \(R_j\) までの照明をループで1つずつ反転させる方法が考えられます。しかし、この方法では1回の操作に最大 \(O(N)\) の時間がかかり、全体で \(O(NM)\) の計算量になってしまいます。今回の制約(\(N \leq 10^6, M \leq 2 \times 10^5\))では、最悪ケースで \(2 \times 10^{11}\) 回程度の計算が必要となり、実行時間制限超過(TLE)となってしまいます。

これを解決するためには、区間に対する足し算を高速に行う いもす法(差分配列) という手法を使います。いもす法を使えば、1回の区間更新を \(O(1)\) で行うことができ、最後に \(O(N)\) で全体の切り替え回数を求めることができます。

アルゴリズム

いもす法を用いて、各照明のスイッチが切り替えられた回数を以下の手順で計算します。

  1. 差分配列の準備 長さ \(N+2\) の配列 diff を用意し、すべて \(0\) で初期化します。(1-indexedで扱い、右端 \(+1\) のインデックスにアクセスするため \(N+2\) のサイズにしています)

  2. 区間の両端のみを更新 \(j\) 回目の操作で区間 \([L_j, R_j]\) のスイッチを切り替えるとき、以下のように値を加減算します。

    • diff[L_j]\(1\) を足す(ここから切り替え回数が \(+1\) されるという意味)
    • diff[R_j + 1] から \(1\) を引く(ここから切り替え回数の増加が元に戻るという意味)

具体例:\(N=5\) で区間 \([2, 4]\) を操作する場合、diff[2]\(+1\)diff[5]\(-1\) をします。

  1. 累積和をとる すべての操作が終わった後、左から順に diff の累積和を計算します。 照明 \(i\) における累積和の値が、その照明のスイッチが切り替えられた合計回数になります。

具体例の続き:累積和をとると、インデックス 1, 2, 3, 4, 5 の値はそれぞれ 0, 1, 1, 1, 0 となり、正しく区間 \([2, 4]\)\(1\) が足されていることが分かります。

  1. 点灯している照明を数える 計算した累積和が奇数である(curr % 2 != 0)照明の個数を数え上げ、それを答えとして出力します。

計算量

  • 時間計算量: \(O(N + M)\)
    • 差分配列への加減算(\(M\) 回の操作)に \(O(M)\)
    • 累積和の計算と答えのカウントに \(O(N)\)
    • 全体として非常に高速に動作し、制限時間に余裕で間に合います。
  • 空間計算量: \(O(N)\)
    • 長さ \(N+2\) の配列 diff を保持するためのメモリを使用します。

実装のポイント

  • 入力の高速化: Pythonでは標準入力 input() を何度も呼ぶと遅延が発生しやすいです。そのため、sys.stdin.read().split() を用いて入力全体を一度に読み込み、リスト化して処理することで大幅な高速化を図っています。

  • 配列外参照の防止: \(R_j = N\) の場合、diff[R_j + 1]diff[N + 1] へのアクセスとなります。配列外参照エラー(IndexError)を防ぐため、配列のサイズは \(N+1\) ではなく \(N+2\) 確保しておくのが安全です。

    ソースコード

import sys

def main():
    input_data = sys.stdin.read().split()
    if not input_data:
        return
    
    N = int(input_data[0])
    M = int(input_data[1])
    
    diff = [0] * (N + 2)
    
    idx = 2
    for _ in range(M):
        l = int(input_data[idx])
        r = int(input_data[idx+1])
        diff[l] += 1
        diff[r + 1] -= 1
        idx += 2
        
    ans = 0
    curr = 0
    for i in range(1, N + 1):
        curr += diff[i]
        if curr % 2 != 0:
            ans += 1
            
    print(ans)

if __name__ == '__main__':
    main()

この解説は gemini-3.1-pro-thinking によって生成されました。

posted:
last update: