Submission #61507127


Source Code Expand

def main
  n = gets.to_i
  a = gets.split.map(&:to_i)
  q = gets.to_i
  wm = WaveletMatrix.new(a)
  
  
  ans = [0]
  q.times do
    l, r, x = gets.split.map(&:to_i)
    l ^= ans[-1]
    r ^= ans[-1]
    x ^= ans[-1]
    
    unless 1<=l && l<=r && r<=n
      exit! 1
    end
    ans << wm.sum_less_than_x(l - 1, r, x + 1)
  end
  
  puts ans[1..]
end

class WaveletMatrix
  attr_reader :n

  def initialize(array)
    @n = array.size
    return if @n == 0

    # 最大値のビット幅を求める
    # max_value = array.max
    @bit_width = array.max.bit_length + 1
    # while max_value > 0
    #   @bit_width += 1
    #   max_value >>= 1
    # end

    # 各レベルの情報を格納する配列を準備
    # @matrix[h][i]       : レベルhでの i番目のビット (0 or 1)
    # @cumulative[h][i]   : レベルhでの rank1(0~i未満の1の数) を求めるための累積配列
    # @sum_level[h][i]    : レベルhにおける「左から i 個目までの合計」(curr の並びに対応)
    # @mid_point[h]       : レベルhにおける「0側ブロックが終わる位置」
    # @zeros_count[h]     : レベルhにおける 0 ビットの個数
    @matrix     = Array.new(@bit_width) { Array.new(@n, 0) }
    @cumulative = Array.new(@bit_width) { Array.new(@n + 1, 0) }
    @sum_level  = Array.new(@bit_width) { Array.new(@n + 1, 0) }
    @mid_point  = Array.new(@bit_width, 0)
    @zeros_count= Array.new(@bit_width, 0)

    curr = array.dup

    @bit_width.times do |h|
      # 上位ビットから順に処理するので、bit_index はこうなる
      bit_index = @bit_width - 1 - h

      # 0側に振り分けられる要素と 1側に振り分けられる要素
      zeros = []
      ones  = []

      # それぞれの累積和(一時的)
      sum_zeros = [0]
      sum_ones  = [0]

      # curr[i] の bit_index ビットを見て 0/1 に振り分ける
      curr.each_with_index do |val, i|
        b = (val >> bit_index) & 1
        @matrix[h][i] = b
        if b == 0
          zeros << val
          sum_zeros << sum_zeros[-1] + val
        else
          ones << val
          sum_ones << sum_ones[-1] + val
        end
      end

      # cumulative[h] を構築 (rank1 計算用)
      @cumulative[h][0] = 0
      @n.times do |i|
        @cumulative[h][i + 1] = @cumulative[h][i] + @matrix[h][i]
      end

      # 0側の要素数
      @zeros_count[h] = zeros.size
      @mid_point[h]   = zeros.size

      # curr = 0側 + 1側 に再構築(次のビットレベルへ渡す)
      curr = zeros + ones

      # sum_level[h] の構築
      #   → curr[0..i-1] の合計値を @sum_level[h][i] に持たせる
      @sum_level[h][0] = 0
      @n.times do |i|
        @sum_level[h][i + 1] = @sum_level[h][i] + curr[i]
      end
    end
  end

  #----------------------------------------------------------------------
  # rank1(level, pos): レベル level における [0, pos) の範囲に登場する '1' の個数
  # rank0(level, pos): 上記の 1 の反転
  #----------------------------------------------------------------------
  def rank1(level, pos)
    return 0 if pos <= 0
    return @cumulative[level][@n] if pos > @n
    @cumulative[level][pos]
  end

  def rank0(level, pos)
    pos - rank1(level, pos)
  end

  #----------------------------------------------------------------------
  # access(k): 元の配列において index=k の要素を取得する (通常の Wavelet Matrix と同様)
  #----------------------------------------------------------------------
  def access(k)
    return nil if k < 0 || k >= @n
    result = 0
    @bit_width.times do |h|
      b = @matrix[h][k]
      result = (result << 1) | b
      if b == 0
        k = rank0(h, k)
      else
        k = @mid_point[h] + rank1(h, k)
      end
    end
    result
  end

  #----------------------------------------------------------------------
  # kth_smallest(l, r, k): 区間 [l, r) の中で k番目に小さい要素 (0-based)
  #----------------------------------------------------------------------
  def kth_smallest(l, r, k)
    return nil if l < 0 || r > @n || l >= r || k < 0 || k >= (r - l)
    result = 0
    current_l = l
    current_r = r

    @bit_width.times do |h|
      zeros_count_in_range = rank0(h, current_r) - rank0(h, current_l)
      if k < zeros_count_in_range
        # ビットが0側へ
        current_l = rank0(h, current_l)
        current_r = rank0(h, current_r)
      else
        # ビットが1側へ
        k -= zeros_count_in_range
        result |= (1 << (@bit_width - 1 - h))
        current_l = @mid_point[h] + rank1(h, current_l)
        current_r = @mid_point[h] + rank1(h, current_r)
      end
    end
    result
  end

  def kth_largest(l, r, k)
    # r-l 個のうち、上から k番目 = 下から (r-l-1-k)番目
    kth_smallest(l, r, (r - l - 1) - k)
  end

  #----------------------------------------------------------------------
  # count(l, r, x): 区間 [l, r) において、値が x に等しい要素数
  #   = freq(< x+1) - freq(< x)
  #----------------------------------------------------------------------
  def count(l, r, x)
    range_freq(l, r, x + 1) - range_freq(l, r, x)
  end

  #----------------------------------------------------------------------
  # range_freq(l, r, upper): 区間 [l, r) で、値が [0, upper) に属する要素数
  #   (元の実装と同様。上限 upper のみを使って数え上げる)
  #----------------------------------------------------------------------
  def range_freq(l, r, upper)
    return 0 if l >= r || upper <= 0
    return (r - l) if upper > (1 << @bit_width)  # 全部入る

    result = 0
    current_l = l
    current_r = r

    @bit_width.times do |h|
      bit = (@bit_width - 1 - h)
      if ((upper >> bit) & 1) == 1
        # 0側は丸ごと加算して、1側だけ探索へ
        zero_count = rank0(h, current_r) - rank0(h, current_l)
        result += zero_count
        current_l = @mid_point[h] + rank1(h, current_l)
        current_r = @mid_point[h] + rank1(h, current_r)
      else
        # 0側だけ探索へ (1側は足さない)
        current_l = rank0(h, current_l)
        current_r = rank0(h, current_r)
      end
    end

    result
  end

  #----------------------------------------------------------------------
  # prev_value(l, r, upper): 区間 [l, r) において upper 未満で最大の値
  #   (無い場合は nil)
  #----------------------------------------------------------------------
  def prev_value(l, r, upper)
    return nil if l >= r || upper <= 0
    val = upper - 1
    while val >= 0
      cnt = range_freq(l, r, val + 1) - range_freq(l, r, val)
      return val if cnt > 0
      val -= 1
    end
    nil
  end

  #----------------------------------------------------------------------
  # next_value(l, r, lower): 区間 [l, r) において lower 以上で最小の値
  #   (無い場合は nil)
  #----------------------------------------------------------------------
  def next_value(l, r, lower)
    return nil if l >= r || lower < 0
    max_value = (1 << @bit_width) - 1
    return nil if lower > max_value

    val = lower
    while val <= max_value
      cnt = range_freq(l, r, val + 1) - range_freq(l, r, val)
      return val if cnt > 0
      val += 1
    end
    nil
  end

  #----------------------------------------------------------------------
  # (NEW) sum_less_than_x(l, r, x):
  #   区間 [l, r) において、「値が x 未満」の要素の合計
  #
  #   上限 x をビットで見ながら、
  #     - 「そのビットが 0」 => 1側(=そのビットが 1 の要素)はすべて排除
  #     - 「そのビットが 1」 => 0側はすべて取り込んで合計に加算し、1側を次に探索
  #
  #   という形で、Wavelet Matrix の各レベルを追いかけます。
  #----------------------------------------------------------------------
  def sum_less_than_x(l, r, x)
    return 0 if l >= r || x <= 0
    # x がビット幅を超えているときは、区間の要素全部を合計してよい
    if x > (1 << @bit_width)
      return sum_range_all(l, r)
    end

    sum = 0
    current_l = l
    current_r = r

    @bit_width.times do |h|
      bit_index = @bit_width - 1 - h
      # x の現在ビットは?
      bit_x = (x >> bit_index) & 1

      if bit_x == 1
        # 0側(このビットが0)の要素は全て「x 未満」確定なので
        # 0側の要素合計を一括で足す
        zero_l = rank0(h, current_l)
        zero_r = rank0(h, current_r)
        zero_count = zero_r - zero_l

        # レベル h で再構築された配列 curr において
        #  0側の要素たちはインデックス [zero_l, zero_r)
        # の範囲に該当する
        #  → sum_level[h][zero_r] - sum_level[h][zero_l] で合計値が取れる
        sum += (sum_level_at(h, zero_r) - sum_level_at(h, zero_l))

        # その後、1側へ進む (bit=1 側)
        current_l = @mid_point[h] + rank1(h, current_l)
        current_r = @mid_point[h] + rank1(h, current_r)
      else
        # bit_x == 0 の場合: 1側は「x未満」にならないのでスルー
        # 0側へ進む
        current_l = rank0(h, current_l)
        current_r = rank0(h, current_r)
      end
    end

    sum
  end

  #----------------------------------------------------------------------
  # (補助) sum_range_all(l, r): 区間 [l, r) の要素合計を全部足す
  #   → 最下位レベル(@bit_width-1)の配列(=最終的に再構成された curr) を使うと簡単
  #----------------------------------------------------------------------
  def sum_range_all(l, r)
    return 0 if l >= r
    # 最後のレベル(@bit_width - 1)における sum_level を利用
    h = @bit_width - 1
    sum_level_at(h, r) - sum_level_at(h, l)
  end

  #----------------------------------------------------------------------
  # (補助) sum_level_at(h, pos): レベル h における 0~pos の合計
  #----------------------------------------------------------------------
  def sum_level_at(h, pos)
    if pos < 0
      0
    elsif pos > @n
      @sum_level[h][@n]
    else
      @sum_level[h][pos]
    end
  end
end

main

Submission Info

Submission Time
Task G - Smaller Sum
User zeronosu77108
Language Ruby (ruby 3.2.2)
Score 600
Code Size 10534 Byte
Status AC
Exec Time 3496 ms
Memory 245516 KiB

Judge Result

Set Name Sample All
Score / Max Score 0 / 0 600 / 600
Status
AC × 1
AC × 51
Set Name Test Cases
Sample sample_01.txt
All sample_01.txt, test_01.txt, test_02.txt, test_03.txt, test_04.txt, test_05.txt, test_06.txt, test_07.txt, test_08.txt, test_09.txt, test_10.txt, test_11.txt, test_12.txt, test_13.txt, test_14.txt, test_15.txt, test_16.txt, test_17.txt, test_18.txt, test_19.txt, test_20.txt, test_21.txt, test_22.txt, test_23.txt, test_24.txt, test_25.txt, test_26.txt, test_27.txt, test_28.txt, test_29.txt, test_30.txt, test_31.txt, test_32.txt, test_33.txt, test_34.txt, test_35.txt, test_36.txt, test_37.txt, test_38.txt, test_39.txt, test_40.txt, test_41.txt, test_42.txt, test_43.txt, test_44.txt, test_45.txt, test_46.txt, test_47.txt, test_48.txt, test_49.txt, test_50.txt
Case Name Status Exec Time Memory
sample_01.txt AC 160 ms 17432 KiB
test_01.txt AC 98 ms 21472 KiB
test_02.txt AC 107 ms 27084 KiB
test_03.txt AC 537 ms 48308 KiB
test_04.txt AC 552 ms 106484 KiB
test_05.txt AC 213 ms 21396 KiB
test_06.txt AC 1775 ms 118096 KiB
test_07.txt AC 2881 ms 187408 KiB
test_08.txt AC 3091 ms 209584 KiB
test_09.txt AC 732 ms 23668 KiB
test_10.txt AC 205 ms 29216 KiB
test_11.txt AC 401 ms 24284 KiB
test_12.txt AC 394 ms 64036 KiB
test_13.txt AC 287 ms 54180 KiB
test_14.txt AC 626 ms 47652 KiB
test_15.txt AC 772 ms 70052 KiB
test_16.txt AC 2181 ms 120920 KiB
test_17.txt AC 1397 ms 167324 KiB
test_18.txt AC 399 ms 31428 KiB
test_19.txt AC 1520 ms 113384 KiB
test_20.txt AC 344 ms 20276 KiB
test_21.txt AC 506 ms 50056 KiB
test_22.txt AC 672 ms 68404 KiB
test_23.txt AC 746 ms 77872 KiB
test_24.txt AC 1333 ms 113516 KiB
test_25.txt AC 2093 ms 174120 KiB
test_26.txt AC 2789 ms 214408 KiB
test_27.txt AC 3357 ms 241820 KiB
test_28.txt AC 3429 ms 241352 KiB
test_29.txt AC 3486 ms 241152 KiB
test_30.txt AC 494 ms 45684 KiB
test_31.txt AC 552 ms 49116 KiB
test_32.txt AC 746 ms 67220 KiB
test_33.txt AC 812 ms 77828 KiB
test_34.txt AC 1423 ms 113284 KiB
test_35.txt AC 2145 ms 175700 KiB
test_36.txt AC 2848 ms 212564 KiB
test_37.txt AC 3477 ms 242104 KiB
test_38.txt AC 3496 ms 241024 KiB
test_39.txt AC 3496 ms 241648 KiB
test_40.txt AC 478 ms 43456 KiB
test_41.txt AC 501 ms 48348 KiB
test_42.txt AC 636 ms 68392 KiB
test_43.txt AC 722 ms 77628 KiB
test_44.txt AC 1217 ms 113856 KiB
test_45.txt AC 1917 ms 174904 KiB
test_46.txt AC 2654 ms 212248 KiB
test_47.txt AC 3437 ms 241172 KiB
test_48.txt AC 3225 ms 242832 KiB
test_49.txt AC 3276 ms 245516 KiB
test_50.txt AC 432 ms 45612 KiB