公式

G - Mode in the Subtree 解説 by Nyaan


今回の問題は一見すると unordered_map (ハッシュマップ) を利用したマージテクで時間計算量 \(\mathrm{O}(N \log N)\) で解けそうです。しかし、今回の問題は \(N \leq 2.5 \times 10^6\) と制約が非常に大きく、unordered_map は時間・空間ともに定数倍が非常に重いことから、マージテクを利用して今回の問題を解くことは非常に難しいでしょう。

想定解では DSU on Tree と呼ばれるアルゴリズムを利用することで、マージテクのデメリットである「過剰なデータ構造を用いることによる定数倍の悪さ」という点を解消して解いています。

DSU on Tree の概要を説明します。DSU on Tree は、「各頂点の部分木に関する情報を、子の計算結果をうまく使い回しながら求める」アルゴリズムです。
各頂点 \(c\) に対して、その部分木に含まれる頂点の情報をまとめたテーブルを手に入れたいです。素朴な方法ではマージテクを利用した方法が考えられますが、それをそのまま実装すると、特に unordered_map のような重いデータ構造を使ったときに定数倍がかなり悪くなってしまいます。
そこで、各頂点について「最も部分木サイズの大きい子」を 1 つ heavy child とし、それ以外を light child とします。heavy child のテーブルだけは捨てずにそのまま使い回し、light child 側の情報は必要になったときだけ後から追加し直します。こうすることで計算量を抑え、定数倍の重いデータ構造を利用することなく問題を解くことが出来ます。

DSU on Tree の具体的なアルゴリズムを説明します。まず、木を重軽分解して(重軽分解については ABC269-Ex 解説 を参照してください) 辺を heavy edge と light edge に塗り分けます。その後、以下の Python 風の疑似コードで表される DFS を行います。

# 頂点の情報を何らかのテーブルで管理することを考える
# add(c): 頂点 c の情報をテーブルに追加する
# query(c): 頂点 c の部分木全体の情報が載ったテーブルに対するクエリに答える
# reset(): テーブルをリセットする

# c は頂点番号
def dfs(c):
  # まず light child の計算をする
  for d in (light children of c):
    dfs(d)
    # light child のテーブルはリセットする
    reset()

  # heavy child の部分木のテーブルを引き継ぐ
  if c is not leaf:
    dfs(heavy child of c)
  
  # テーブルに light child の子孫の情報を追加する
  for d in (light children of c):
    for n in (descendants of d):
      add(n)

  # テーブルに c 自身の情報を反映する
  add(c)

  # クエリに解答する
  query(c)
  
  return

このアルゴリズムの計算量を考えてみましょう。ある頂点 n に対して add 関数が呼び出される回数を考えると、これは (根から頂点 n までのパスに含まれる light edge の本数) + 1 回になることが確認できます。これは、DFS の過程で heavy child のテーブルを引き継いでいることから従います。よってアルゴリズム全体で、

  • add 関数は \(\mathrm{O}(N \log N)\) 回、
  • query 関数は \(N\) 回、
  • reset 関数は \(\mathrm{O}(N)\) 回呼び出されています。

よって、add 関数, query 関数, reset 関数の時間計算量が十分高速であれば、アルゴリズム全体の計算量は \(\mathrm{O}(N \log N)\) になります。

今回の問題では、テーブルとして次の 3 つを持てば十分です。

  • cnt[x]: 現在テーブルに入っている頂点のうち、色 x の頂点数
  • num[t]: 現在テーブルに入っている色のうち、ちょうど t 回現れている色数
  • mx: 現在の最大出現回数

色番号は高々 \(N\) なので、これらはすべて配列で管理できます。頂点 v を追加する add(v) では、x = c_v として cnt[x] を 1 増やし、それに合わせて nummx を更新すればよいです。 query(v) の答えは単に (mx, num[mx]) です。また reset() についても、毎回配列全体を初期化する必要はありません。変更した色や頻度だけを記録しておき、それらだけを 0 に戻すようにすればよいです。

マージテクを利用した解法では unordered_map という非常に定数倍の重いデータ構造を使う関係で、時間・空間ともに効率の悪い解法になってしまっていました。一方で DSU on Tree を用いると、同じ「小さいものを大きいものに足す」という発想を、配列ベースで軽く実現できます。今回の問題のように全ての処理を配列で行える場合は、DSU on Tree の方が実用的であるといえるでしょう。

以上の解法を適切に実装すれば今回の問題を解くことが出来ます。時間計算量は \(\mathrm{O}(N \log N)\)、空間計算量は \(\mathrm{O}(N)\) となり十分高速です。

  • 実装例(Python, codon) 定数倍がやや厳しい場合は実装例のように Euler Tour を HLD と併用すると定数倍を改善できます。
N, state, M, F = [int(x) for x in input().split()]
Q = [int(x) for x in input().split()]
D = [int(x) for x in input().split()]

p = [-1] * N
g = [list() for _ in range(N)]
for i in range(2, N + 1):
    if i <= M:
        p[i - 1] = Q[i - 2] - 1
    else:
        p[i - 1] = state % (i - 1)
        state = (state * 1103515245 + 12345) & ((1 << 31) - 1)
    g[p[i - 1]].append(i - 1)

C = [-1] * N
for i in range(1, N + 1):
    if i <= M:
        C[i - 1] = D[i - 1] - 1
    else:
        C[i - 1] = state % F
        state = (state * 1103515245 + 12345) & ((1 << 31) - 1)

sub = [0] * N
for c in reversed(range(N)):
    sub[c] = 1
    for j in range(len(g[c])):
        d = g[c][j]
        sub[c] += sub[d]
        if sub[d] > sub[g[c][0]]:
            g[c][j], g[c][0] = g[c][0], g[c][j]

idx, euler, down, up = 0, [0] * N, [0] * N, [0] * N
st = [~0, 0]
while len(st) > 0:
    c = st.pop()
    if c >= 0:
        euler[idx] = c
        down[c] = idx
        idx += 1
        for d in reversed(g[c]):
            st.append(~d)
            st.append(d)
    else:
        up[~c] = idx

mx, ans, cnt, freq = 0, 0, [0] * N, [0] * (N + 1)


def add(c):
    global mx
    col = C[c]
    freq[cnt[col]] -= 1
    cnt[col] += 1
    freq[cnt[col]] += 1
    mx = max(mx, cnt[col])


def query(c):
    global ans, mx
    ans += (mx ^ (c + 1)) * (freq[mx] ^ (c + 1))
    ans %= 998244353


def reset(c):
    global mx
    for i in range(down[c], up[c]):
        col = C[euler[i]]
        cnt[col] = 0
    for i in range(mx + 1):
        freq[i] = 0
    mx = 0


def dsu(c):
    global mx, ans
    for i in range(1, len(g[c])):
        dsu(g[c][i])
        reset(g[c][i])
    if len(g[c]):
        dsu(g[c][0])
        for i in range(up[g[c][0]], up[c]):
            add(euler[i])
    add(c)
    query(c)


dsu(0)
print(ans)

投稿日時:
最終更新: