Official

E - 区間のブロック数 / Number of Blocks in an Interval Editorial by admin

GPT 5.4 High

概要

区間の色を一括で塗り替えながら、任意の区間の「ブロック数」を求める問題です。
各区間について 左端の色・右端の色・ブロック数 を持つ遅延セグメント木を使うと、更新も問い合わせも \(O(\log N)\) で処理できます。

考察

まず、ブロック数とは「隣り合う色が変わる場所の数 \(+1\)」です。
たとえば

  • \(1,1,2,2,2,1\)

なら、色が変わるのは

  • \(1 \to 2\)
  • \(2 \to 1\)

の 2 か所なので、ブロック数は \(2+1=3\) です。

素朴な方法が遅い理由

  • 更新 1 L R X を愚直に行うと、区間 \([L,R]\) の各要素を書き換える必要があり \(O(R-L+1)\)
  • 問い合わせ 2 L R も、区間をなめて色の変化回数を数えると \(O(R-L+1)\)

最悪ではどちらも \(O(N)\) かかるため、全体で \(O(NQ)\) となり間に合いません。

区間をどう要約すればよいか

ある区間を左半分・右半分に分けたとき、その区間全体のブロック数を知るには何が必要でしょうか。

実は、各部分区間について次の 3 つが分かれば十分です。

  • その区間の 左端の色
  • その区間の 右端の色
  • その区間の ブロック数

なぜなら、左区間と右区間をつなげるときに影響するのは、境界の 2 色だけ だからです。

左区間を \((lc_1, rc_1, cnt_1)\)、右区間を \((lc_2, rc_2, cnt_2)\) とすると、

  • もし \(rc_1 = lc_2\) なら、境界の 2 つのブロックはつながる
  • そうでなければ、そのまま別ブロック

したがってマージ後は

  • 左端の色:\(lc_1\)
  • 右端の色:\(rc_2\)
  • ブロック数:\(cnt_1 + cnt_2 - [rc_1 = lc_2]\)

となります。
ここで \([P]\) は、条件 \(P\) が真なら 1、偽なら 0 です。

更新が簡単になる理由

区間 \([L,R]\) をすべて色 \(X\) に塗ると、その区間は 色が全部同じ になります。
つまりその区間の情報は即座に

  • 左端の色:\(X\)
  • 右端の色:\(X\)
  • ブロック数:\(1\)

と分かります。

この性質があるので、区間代入 と相性のよい 遅延セグメント木 が使えます。


アルゴリズム

1. セグメント木の各ノードに持たせる情報

各ノード(ある区間を表す)に対して、次を持たせます。

  • lc : その区間の左端の色
  • rc : その区間の右端の色
  • cnt: その区間のブロック数

葉では 1 マスだけを表すので、色を \(c\) とすると

  • lc = c
  • rc = c
  • cnt = 1

です。

2. 2 つの子区間をマージする

左の子を \(A\)、右の子を \(B\) とすると、

  • lc = A.lc
  • rc = B.rc
  • cnt = A.cnt + B.cnt - (A.rc == B.lc)

で親の情報を作れます。

左区間が 1,1,2 なら

  • lc=1
  • rc=2
  • cnt=2

右区間が 2,2,3 なら

  • lc=2
  • rc=3
  • cnt=2

この 2 つをつなげると、境界は 22 で同じなので 1 つつながり、

  • cnt = 2 + 2 - 1 = 3

となります。実際に並べると 1,1,2,2,2,3 でブロック数は 3 です。

3. 遅延伝播による区間代入

更新 1 L R X では、区間全体を色 \(X\) にします。

あるノードが表す区間が更新区間に完全に含まれるなら、その区間は全部同色になるので

  • lc = X
  • rc = X
  • cnt = 1

にしてしまえばよいです。
さらに、その子にも同じ更新を後で反映できるように lazy\(X\) を記録しておきます。

このようにすることで、毎回葉まで降りずに更新できます。

4. 区間クエリ

問い合わせ 2 L R では、区間 \([L,R]\) のブロック数を求めます。

セグメント木の区間クエリでは、対象区間をいくつかのノード区間に分解して集めます。
ただしこの問題では、ただブロック数を足すだけではだめ です。隣り合う区間の境界でブロックがつながる可能性があるからです。

そこで、クエリ中に次の 2 つの累積情報を持ちます。

  • 左から集めた部分の情報
  • 右から集めた部分の情報

それぞれについて

  • 左側累積:右端の色とブロック数
  • 右側累積:左端の色とブロック数

を管理し、ノードを取るたびに順番を保ってマージします。
最後に左側累積と右側累積を 1 回マージすれば、求める区間全体のブロック数になります。

5. 実装上の区間表現

コードでは内部的に 0-indexed の半開区間 \([l,r)\) を使っています。

入力は 1-indexed の閉区間 \([L,R]\) なので、

  • L -> L-1
  • R -> R

として \([L-1, R)\) に変換しています。


計算量

  • 時間計算量: \(O((N+Q)\log N)\)
  • 空間計算量: \(O(N)\)

初期構築が \(O(N)\)、各更新・各問い合わせがいずれも \(O(\log N)\) です。

実装のポイント

  • ノードの要約情報は 3 つだけで十分
    • 左端色
    • 右端色
    • ブロック数
  • 区間代入されたノードは必ずブロック数 1
    • その区間が全部同じ色になるため
  • クエリでは順番が重要
    • 左から取った区間と右から取った区間を別々に管理し、最後にマージする
  • 未使用の葉に注意
    • セグメント木のサイズは 2 のべき乗にするため、実際には存在しない要素の葉ができる
    • このコードでは cnt = 0 を「空区間」として扱い、pull で丁寧に処理している
  • lazy = 0 を「未更新」の印として使える
    • 問題の色 \(C_i, X\) はすべて 1 以上なので、0 は特別な値として安全に使える

このように、「区間をマージするときに必要な情報だけを持つ」 ことが、この問題の本質です。

ソースコード

import sys

def main():
    data = list(map(int, sys.stdin.buffer.read().split()))
    if not data:
        return

    idx = 0
    N = data[idx]
    Q = data[idx + 1]
    idx += 2

    A = data[idx:idx + N]
    idx += N

    size = 1
    while size < N:
        size <<= 1
    log = size.bit_length() - 1

    m = size << 1
    lc = [0] * m
    rc = [0] * m
    cnt = [0] * m
    lazy = [0] * size

    for i, c in enumerate(A):
        k = size + i
        lc[k] = c
        rc[k] = c
        cnt[k] = 1

    def pull(k):
        l = k << 1
        r = l | 1
        cl = cnt[l]
        cr = cnt[r]
        if cl == 0:
            cnt[k] = cr
            lc[k] = lc[r]
            rc[k] = rc[r]
        elif cr == 0:
            cnt[k] = cl
            lc[k] = lc[l]
            rc[k] = rc[l]
        else:
            cnt[k] = cl + cr - (rc[l] == lc[r])
            lc[k] = lc[l]
            rc[k] = rc[r]

    for k in range(size - 1, 0, -1):
        pull(k)

    def apply(k, x):
        lc[k] = x
        rc[k] = x
        cnt[k] = 1
        if k < size:
            lazy[k] = x

    def push(k):
        x = lazy[k]
        if x:
            l = k << 1
            r = l | 1

            lc[l] = x
            rc[l] = x
            cnt[l] = 1
            if l < size:
                lazy[l] = x

            lc[r] = x
            rc[r] = x
            cnt[r] = 1
            if r < size:
                lazy[r] = x

            lazy[k] = 0

    def range_apply(l, r, x):
        l += size
        r += size
        l0 = l
        r0 = r

        for i in range(log, 0, -1):
            if ((l0 >> i) << i) != l0:
                push(l0 >> i)
            if ((r0 >> i) << i) != r0:
                push((r0 - 1) >> i)

        while l < r:
            if l & 1:
                apply(l, x)
                l += 1
            if r & 1:
                r -= 1
                apply(r, x)
            l >>= 1
            r >>= 1

        for i in range(1, log + 1):
            if ((l0 >> i) << i) != l0:
                pull(l0 >> i)
            if ((r0 >> i) << i) != r0:
                pull((r0 - 1) >> i)

    def range_query_count(l, r):
        l += size
        r += size

        for i in range(log, 0, -1):
            if ((l >> i) << i) != l:
                push(l >> i)
            if ((r >> i) << i) != r:
                push((r - 1) >> i)

        left_r = 0
        left_c = 0
        right_l = 0
        right_c = 0

        while l < r:
            if l & 1:
                c = cnt[l]
                if c:
                    if left_c == 0:
                        left_r = rc[l]
                        left_c = c
                    else:
                        left_c = left_c + c - (left_r == lc[l])
                        left_r = rc[l]
                l += 1
            if r & 1:
                r -= 1
                c = cnt[r]
                if c:
                    if right_c == 0:
                        right_l = lc[r]
                        right_c = c
                    else:
                        right_c = c + right_c - (rc[r] == right_l)
                        right_l = lc[r]
            l >>= 1
            r >>= 1

        if left_c == 0:
            return right_c
        if right_c == 0:
            return left_c
        return left_c + right_c - (left_r == right_l)

    ans = []
    append = ans.append

    for _ in range(Q):
        t = data[idx]
        idx += 1
        if t == 1:
            L = data[idx] - 1
            R = data[idx + 1]
            X = data[idx + 2]
            idx += 3
            range_apply(L, R, X)
        else:
            L = data[idx] - 1
            R = data[idx + 1]
            idx += 2
            append(str(range_query_count(L, R)))

    sys.stdout.write("\n".join(ans))

if __name__ == "__main__":
    main()

この解説は gpt-5.4-high によって生成されました。

posted:
last update: