Official

E - King Bombee Editorial by tatyam


「数列 \(A\) の中に整数 \(X\) は偶数回出現する」という条件が厄介なので、まずはこれを取り除いて考えてみましょう。すると、「頂点 \(S\) から辺を \(K\) 回通って頂点 \(T\) へ行く方法は何通り?」という問題になります。これは、ABC242 C - 1111gal password のように DP (動的計画法) で解くことができます。

具体的には、

\(dp[i][j] = {}\)(頂点 \(S\) から辺を \(i\) 回通って頂点 \(j\) へ行く方法の数)

を計算すれば良いです。漸化式は、頂点 \(i\) と直接辺でつながっている頂点の集合を \(\text{adj}(i)\) として、

\(dp[0][j] = \begin{cases} 1 & (j = S) \\ 0 & (j ≠ S)\end{cases}\)
\(dp[i + 1][j] = \displaystyle\sum_{k \in \text{adj}(j)}dp[i][k]\)

となります。\(dp[K][T]\) が求める答えです。時間計算量は \(O((N + M) K)\) です。

さて、「数列 \(A\) の中に整数 \(X\) は偶数回出現する」という条件を扱うためには、DP の計算に「これまでに頂点 \(X\) を通った回数が偶数か奇数か」の状態を持たせる必要があります。DP の計算中に頂点 \(X\) を通ったら、偶数から奇数へ、奇数から偶数へ状態を変化させれば良いです。

このような状態を追加するには、\(dp\) の添字を \(1\) つ増やせば良いです。すなわち、

\(dp[i][j][x] = {}\)(頂点 \(S\) から辺を \(i\) 回通って頂点 \(j\) へ行き、途中で頂点 \(X\) を通った回数\({}\bmod 2\)\(x\) であるような方法の数)

とします。漸化式は、

\(dp[0][j][0] = \begin{cases} 1 & (j = S) \\ 0 & (j ≠ S)\end{cases}\)
\(dp[0][j][1] = 0\)
\(dp[i + 1][j][x] = \displaystyle\sum_{k \in \text{adj}(j)}dp[i][k][x]\ (j ≠ X)\)
\(dp[i + 1][X][x] = \displaystyle\sum_{k \in \text{adj}(X)}dp[i][k][1 - x]\)

となります。\(dp[K][T][0]\) が求める答えです。時間計算量は先ほどと変わらず \(O((N + M) K)\) なので、これでこの問題を解くことができました。

\(998244353\) で割ったあまりについては、各計算の後に毎回 \(998244353\) で割ったあまりを取るようにするか、AC Library の modint を使えば良いです。

実装例 (C++)

#include <atcoder/modint>
#include <array>
#include <iostream>
#include <vector>
using namespace std;
using Modint = atcoder::modint998244353;

int main(){
    int N, M, K, S, T, X;
    cin >> N >> M >> K >> S >> T >> X;
    S--; T--; X--;
    vector<pair<int, int>> edge(M);
    for(auto& [U, V] : edge){
        cin >> U >> V;
        U--; V--;
    }
    vector dp(K + 1, vector(N, array<Modint, 2>{0, 0}));
    dp[0][S][0] = 1;
    for(int i = 0; i < K; i++){
        for(auto [U, V] : edge) for(int x : {0, 1}){
            dp[i + 1][V][x ^ (V == X)] += dp[i][U][x];
            dp[i + 1][U][x ^ (U == X)] += dp[i][V][x];
        }
    }
    cout << dp[K][T][0].val() << endl;
}

実装例 (Python)

MOD = 998244353
N, M, K, S, T, X = map(int, input().split())
S -= 1
T -= 1
X -= 1

edge = []
for i in range(M):
    U, V = map(int, input().split())
    U -= 1
    V -= 1
    edge.append((U, V))

dp = [[[0] * N for i in range(K + 1)] for x in range(2)]
dp[0][0][S] = 1

for i in range(K):
    for U, V in edge:
        for x in range(2):
            dp[x][i + 1][V] += dp[x ^ (V == X)][i][U]
            if dp[x][i + 1][V] >= MOD:
                dp[x][i + 1][V] -= MOD
            dp[x][i + 1][U] += dp[x ^ (U == X)][i][V]
            if dp[x][i + 1][U] >= MOD:
                dp[x][i + 1][U] -= MOD

print(dp[0][K][T])

posted:
last update: