Official

G - Delivery on Tree Editorial by yuto1115

解説

まず、カゴが動く経路は全ての辺をちょうど \(2\) 回ずつ通り頂点 \(1\) に戻ってくるようなものに限定されていることから、経路を定めることは、子を \(2\) つ持つ各頂点についてそれらの子を訪れる順序を定めることと等価です。

各ボールを正しく移動させるために操作列に対して課される条件について考えます。まず、以下が定まります。

  • (条件 A)子を \(2\) つ持つ各頂点 \(i\) について、「左の子を先に訪れなければならない」「右の子を先に訪れなければならない」「どちらの子を先に訪れてもよい」のいずれであるか

具体的には、\(S_j\)\(T_j\) の最小共通祖先 \(\text{lca}(S_j,T_j)\)\(S_j\) とも \(T_j\) とも一致しないとき、\(\text{lca}(S_j,T_j)\) の子を訪れる順序に制約がかかります。これらの制約が矛盾を生じた場合、答えは明らかに \(0\) です。

次に、各ボールについてそのボールをカゴに入れる/カゴから出す最適なタイミングを考えることで、以下が定まります。

  • (条件 B)各頂点 \(i\) の各 \(c\in \{\)左の子\(,\)右の子\(,\)\(\}\) について、「\(i\) から \(c\) に移動する直前にカゴに入れるボールの数」および「\(c\) から \(i\) に移動した直後にカゴから出すボールの数」

例えば、\(S_j\)\(T_j\) の祖先であり、\(S_j\) の右の子が \(S_j-T_j\) パス上に存在する場合、\(S_j\) から \(S_j\) の右の子に移動する直前にボール \(j\) をカゴに入れ、\(T_j\) の親から \(T_j\) に移動した直後にカゴから出すのが最適です。

あとは、条件 A を満たす経路のうち、条件 B に沿ってボールを出し入れしたときにカゴの中のボールが \(K\) 個を超えないようなものの数を数えればよいです。これは以下の DP によって求まります。

  • \(dp[i][j]\) : \(j\) 個のボールがカゴに入った状態で頂点 \(i\) に初めて訪れたとき、
    • 部分木 \(i\) 内の頂点を訪れる順番のうち、条件を満たすものは何通りあるか
    • 部分木 \(i\) 内の頂点を全て訪れて \(i\) の親に戻るとき、カゴの中に入っているボールは何個か(なお、この値を \(j'\) としたとき、\(j'-j\) の値は \(j\) によらず一定です。)

遷移等は自然であり特筆すべき点はありませんが、実際のところこの問題の本質は実装です。以下の実装例では、実際には起こり得ない \((i,j)\) のペアに対して DP 値を計算しようとしたときに発生する煩雑さを省くため、メモ化再帰で実装しています。また、子の数による場合分けを避けるために順列全探索を用いています。

実装例 (C++) :

#include<bits/stdc++.h>
#include<atcoder/modint>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

vector<vector<int>> par;
vector<int> dep;

void lca_init() {
    int n = par[0].size();
    for (int k = 0; k < 19; k++) {
        for (int i = 0; i < n; i++) {
            if (par[k][i] == -1) continue;
            par[k + 1][i] = par[k][par[k][i]];
        }
    }
    dep.resize(n);
    for (int i = 1; i < n; i++) {
        dep[i] = dep[par[0][i]] + 1;
    }
}

int la(int u, int d) {
    for (int k = 0; k < 20; k++) {
        if (d >> k & 1) u = par[k][u];
    }
    return u;
}

int lca(int u, int v) {
    if (dep[u] > dep[v]) swap(u, v);
    v = la(v, dep[v] - dep[u]);
    if (u == v) return u;
    for (int k = 19; k >= 0; k--) {
        if (par[k][u] != par[k][v]) {
            u = par[k][u];
            v = par[k][v];
        }
    }
    return par[0][u];
}

int main() {
    int n, m, k;
    cin >> n >> m >> k;
    vector<vector<int>> ch(n);
    par = vector(20, vector<int>(n, -1));
    for (int i = 1; i < n; i++) {
        int p;
        cin >> p;
        ch[--p].push_back(i);
        par[0][i] = p;
    }
    lca_init();
    vector in(n, vector<int>(3)); // 0 : left child  1 : right child  2 : parent
    vector out(n, vector<int>(3));
    vector<int> first(n, -1);
    for (int i = 0; i < m; i++) {
        int s, t;
        cin >> s >> t;
        --s, --t;
        int l = lca(s, t);
        int cs = -1, ct = -1;
        if (l != s) {
            int now = la(s, dep[s] - dep[l] - 1);
            for (int j = 0; j < (int) ch[l].size(); j++) {
                if (ch[l][j] == now) cs = j;
            }
        }
        if (l != t) {
            int now = la(t, dep[t] - dep[l] - 1);
            for (int j = 0; j < (int) ch[l].size(); j++) {
                if (ch[l][j] == now) ct = j;
            }
        }
        if (l == s) {
            ++in[s][ct];
            ++out[t][2];
        } else if (l == t) {
            ++in[s][2];
            ++out[t][cs];
        } else {
            ++in[s][2];
            ++out[t][2];
            if (first[l] == ct) {
                cout << 0 << endl;
                return 0;
            }
            first[l] = cs;
        }
    }
    vector dp(n, vector<pair<int, mint>>(k + 1));
    vector seen(n, vector<bool>(k + 1));
    auto f = [&](auto &f, int i, int j) -> pair<int, mint> {
        if (seen[i][j]) return dp[i][j];
        seen[i][j] = true;
        dp[i][j] = {-1, 0};
        vector<int> ord(ch[i].size());
        iota(ord.begin(), ord.end(), 0);
        do {
            if (!ord.empty() and first[i] == 1 - ord[0]) continue;
            int nj = j;
            mint now = 1;
            nj -= out[i][2];
            for (int c: ord) {
                nj += in[i][c];
                if (nj > k) {
                    nj = -1;
                    break;
                }
                mint m;
                tie(nj, m) = f(f, ch[i][c], nj);
                if (nj == -1) break;
                now *= m;
                nj -= out[i][c];
            }
            if (nj == -1) continue;
            nj += in[i][2];
            if (nj > k) continue;
            dp[i][j].first = nj;
            dp[i][j].second += now;
        } while (next_permutation(ord.begin(), ord.end()));
        return dp[i][j];
    };
    cout << f(f, 0, 0).second.val() << endl;
}

posted:
last update: