Official

C - Lights Out on Tree Editorial by Nyaan


まず、ボタンを押す操作について、次のような性質が成り立ちます。

  • 同じボタンを \(2\) 回押した状態は押す前の状態と同じである
  • ボタンを押す順番を変えても得られる状態は同じ

正当性はナイーブに論証していくことで示せますが、 \(\text{mod }2\) 上の線形代数を利用するとスマートに説明できます。

この事実から、同じボタンは多くとも \(1\) 回しか押さなくてよいことがわかります。つまり、各ボタンについて押すか押さないかの 2 通り、全体で \(2^N\) 通りの押し方を考えればよいです。

ボタン全体の集合を \(B = \lbrace 1,2,\dots,N\rbrace\) とします。クエリで頂点集合 \(S\) が与えられたときに、次の条件を満たすようなボタンの集合 \(T \subseteq B\) を探す問題になります。

  • \(T\) に含まれるボタンを \(1\) 回ずつ押すことで \(S\) に含まれるコインのみをすべて裏返すことができる。\((\ast)\)

ここで、次の事実が成り立ちます。

クエリで与えられる頂点集合 \(S\) を固定したとき、条件 \((\ast)\) を満たすボタンの集合 \(T\) はちょうど 1 通りに定まる。

はじめに \(T\) は少なくとも 1 通り存在することを示します。

頂点集合 \(S\) がサイズ \(1\)\(S = \lbrace v \rbrace\) である場合を考えます。この場合、\(v\) 自身、および \(v\) の全ての子に対応するボタンを押せば \(v\) のコインのみを裏返せるのが確認できます。(このようなボタンの集合を \(T_v\) とします。)

一般の場合を考えます。 \(S = \lbrace v_1, v_2, \dots, v_m \rbrace\) とします。このとき、「\(T_{v_1}, T_{v,2}, \dots, T_{v,m}\) に奇数回登場するボタンからなる集合」は条件を満たします。(これは前述した「同じボタンを \(2\) 回押した状態は押す前の状態と同じである」という性質から示せます。) よって \(S\) に対応する \(T\) が少なくとも \(1\) 個存在するのが言えました。

次に高々 \(1\) 個なのも確認します。非空な頂点集合 \(S\) としてあり得るのは \(2^N - 1\) 通りです。一方、ボタンの集合としてあり得るのは \(2^N\) 通りで、このうち \(T = \emptyset\) の場合は何も裏返らず条件を満たさないので、\(S\) に対応する \(T\) としてあり得るのもまた \(2^N - 1\) 通りです。(\(S\) としてあり得る集合の個数) = (\(T\) としてあり得る集合の個数) なので、各 \(S\) に対応する \(T\)\(1\) 個しか存在しないこともまた明らかです。

以上より 条件を満たすボタンの集合は常に \(1\) 通りなのが示せました。あとは \(S\) に対応する \(T\) のサイズを高速に求められればこの問題を解くことができます。
\(T\) は「\(T\) が少なくとも 1 通り以上存在する証明」の部分で出てきた方法でサイズを計算出来ます。愚直に計算すると \(\mathrm{O}(N)\) 掛かりますが、たとえば頂点を BFS order で並び替えてセグメントツリーなどのデータ構造を用いれば高速に計算できます。
以上の考察でも解く上では十分ですが、更にいくらかの考察を行うと (考察の内容は省略します) 高度なデータ構造は不要になり、次の手順を実装すれば AC することができます。

  • あらかじめ各頂点の子の個数を計算しておく。
  • \(i\) 番目のクエリの答えは \(\displaystyle \sum_{j=1}^{M_i}\) (\(v_{i,j}\) の子の個数 + 1) になる。ただし、頂点集合 \(S_i\) の中に (親,子) の組があるごとに答えから \(2\) を引く。(これは \(S_i\) を連想配列で管理すれば高速に処理できる。)

計算量は \(\mathrm{O}(N + (\sum_i M_i) \log (\sum_i M_i))\) 程度です。

  • 実装例(PyPy)
N, Q = map(int, input().split())
par = [-1, -1] + list(map(int, input().split()))
chd = [0] * (N + 1)
for i in range(2, N + 1):
  chd[par[i]] += 1
for _ in range(Q):
  __, *v = map(int, input().split())
  m = {*v}
  print(sum((chd[c] + (-1 if par[c] in m else 1) for c in v)))

posted:
last update: