G - Leaf Color Editorial by tonegawa

マージテクと木DPを使った解法

方針

与えられる木の根を頂点\(1\)として、条件を満たす頂点の部分集合のうち最も深さの小さい頂点が\(v \) であるようなものは何通りあるか?

これを全ての\(v\) について計算して足し合わせることを目標とします。

DPテーブルの持ち方

以下のようにDPテーブルを持ちます。

  • \(dp[v][c] := \) 頂点 \(v\) の部分木の頂点集合の部分集合であって、 誘導部分グラフが木になり、かつ\(v\) を含み、かつ葉になることが確定している頂点の色が全て \(c\) であるものの場合の数

葉になることが確定している頂点とは、\(v\)が孤立点でない場合葉の頂点集合から \(v\) を除いたものを指します。 (\(v\)は親頂点と繋がれて葉でなくなる可能性があるためです)

答えへの寄与の計算

誘導部分グラフについて\(3\) 通りの状態を考える必要があります。

  1. \(v\) が孤立点
  2. \(v\) の子頂点を\(1\) つだけ含む
  3. \(v\) の子頂点を\(2\) つ以上含む

1の場合は任意の\(v\) について常に1通り存在します。 (下記コードの(1))

2の場合\(v\) が葉となるため、子頂点側に含まれる葉が\(v\)と同じ色であることが必要十分条件です。(コードの(2)部分)

3の場合 \(v\) は葉にならないため、子頂点側に含まれる葉が全て同じ色であることが必要十分条件です。(コードの(3)部分)

DPテーブルの更新

\(v\) の子頂点 \(ch\) 側の部分木をマージするとき、\(ch\) 側の部分木に含まれる色 \(c\)に対して

\[dp[v][c] \larr dp[v][c] + dp[v][c] * dp[ch][c] + dp[ch][c]\]

と更新できます。右辺の項はそれぞれ(\(ch\)以前の子頂点のみ使う)、(\(ch\)以前と\(ch\)両方使う)、(\(ch\) のみ使う) 場合に相当します。

マージには最悪で \(ch\) の部分木サイズ分の計算が必要ですが、サイズの小さい方のテーブルをサイズの大きい方のテーブルにマージすることにすると全体の時間計算量が\(O(N logN)\) になります。(より詳しくはマージテクと検索してください)

今回の場合式が対称的なのでテーブルをswapするだけで良いです。

コード

#include <iostream>
#include <vector>
#include <unordered_map>
#include <atcoder/modint>
using namespace std;
using mint = atcoder::static_modint<998244353>;

int main(){
  cin.tie(nullptr);
  ios::sync_with_stdio(false);
  int n;
  cin >> n;
  vector<int> a(n);
  for(int i = 0; i < n; i++) cin >> a[i];
  vector<vector<int>> g(n);
  // グラフの作成
  for(int i = 0; i < n - 1; i++){
    int x, y;
    cin >> x >> y;
    x--, y--;
    g[x].push_back(y);
    g[y].push_back(x);
  }

  vector<unordered_map<int, mint>> D(n);
  mint ans = 0;
  // mpにキーがkである要素がある場合その値を, ない場合0を返す
  auto get = [](unordered_map<int, mint> &mp, int k) -> mint {
    auto itr = mp.find(k);
    return itr == mp.end() ? 0 : itr->second;
  };
  // mpにキーがkである要素がある場合xを足す, ない場合xを追加
  auto add = [](unordered_map<int, mint> &mp, int k, mint x) -> void {
    auto itr = mp.find(k);
    if(itr == mp.end()) mp.emplace(k, x);
    else itr->second += x;
  };
  // 木dp
  auto f = [&](auto &&f, int v, int p) -> void {
    for(int c : g[v]){
      if(c == p) continue;
      f(f, c, v);
      ans += get(D[c], a[v]); // (2)
      
      // |D[v]| >= |D[c]|にする(マージテク)
      if(D[v].size() < D[c].size()) D[v].swap(D[c]);
      
      for(auto [color, cnt] : D[c]){
        mint x = get(D[v], color);
        mint y = cnt * x;
        ans += y; // (3)
        add(D[v], color, y + cnt);
      }
    }
    ans++; // (1)
    add(D[v], a[v], 1);
  };
  f(f, 0, -1);
  cout << ans.val() << '\n';
}

提出URL: https://atcoder.jp/contests/abc340/submissions/50218488 (178ms)

posted:
last update: