Official

D - Avoid K Palindrome Editorial by MMNMM


A, B, ? からなる文字列 \(T\) に対して、\(\operatorname{dp} _ T[P]\) を次のように定めます。

  • \(T\) に含まれる ? をそれぞれ A, B に置換してできる文字列のうち、次の \(2\) つの条件をどちらも満たすものの個数
    • 良い文字列である
    • 末尾 \(\min\lbrace K-1,|T|\rbrace\) 文字が文字列 \(P\) と等しい

A, B, ? からなる文字列 \(T\) について、ありえるすべての \(P\) に対する \(\operatorname{dp} _ T[P]\) の値がわかっているとします。
このとき、\(\operatorname{dp} _ {T+\mathtt{A}}[P]\) や \(\operatorname{dp} _ {T+\mathtt{B}}[P],\operatorname{dp} _ {T+\mathtt{?}}[P]\) の値が高速に求められれば、動的計画法を用いてこの問題を解くことができます(ここで、文字列 \(X\) の末尾に文字 \(c\) を追加したものを \(X+c\) と書いて表すこととします)。

\(\operatorname{dp} _ T\) の値から \(\operatorname{dp} _ {T+\mathtt A}\) を求めるアルゴリズムについて考えます(B に対してもほぼ同じで、? に対してはこれらの和を求めればよいです)。

  • はじめ、どのような \(P\) に対しても \(\operatorname{dp} _ {T+\mathtt A}[P]=0\) として初期化する。
  • ありえる \(P\) すべてに対して、次を行う。
    • \(P+\mathtt A\) が長さ \(K\) の回文なら、なにもしない。
    • そうでなければ、\(P+\mathtt A\) の末尾 \(\min\lbrace K-1,|P|+1\rbrace\) 文字を取った文字列を \(P ^ \prime\) として \(\operatorname{dp} _ {T+\mathtt A}[P ^ \prime]\) に \(\operatorname{dp} _ T[P]\) を加える。

これを用いて与えられた文字列 \(S\) に対する \(\operatorname{dp} _ S\) を求めることを考えます。 これは、回文判定に \(\Theta(K)\) 時間かけた場合でも全体で \(O(2 ^ KKN)\) 時間となり、十分高速です。

実装の方針として、DP テーブルのキーとして文字列を使う方針や非負整数を使う方針があります。

文字列を使うと、キーがわかりやすく、長さが \(K\) 未満の文字列に対する DP テーブルを求める際に番兵を用いた実装を使うことで実装をシンプルにしやすいですが、定数倍が悪い場合があります。
非負整数を使うと、時間・空間計算量の定数倍がよいですが、実装が煩雑になる場合があります。

実装例は以下のようになります。

C++ での実装例では DP テーブルのキーを非負整数とし、\(S\) が短い部分の処理を別に分けています。 Python での実装例では DP テーブルのキーを文字列とし、番兵を置くことで実装を単純にしています。

#include <algorithm>
#include <atcoder/modint>
#include <iostream>
#include <numeric>
#include <ranges>
#include <string>
#include <vector>

int main() {
    using namespace std;
    using modint = atcoder::static_modint<998244353>;

    unsigned N, K;
    cin >> N >> K;

    // 長さ K の 0/1 文字列のうち、回文でないもの
    // を二進法で表された整数として解釈したもの
    vector<unsigned> not_palindrome;
    for (unsigned i = 0; i < 1U << K; ++i) {
        bool is_palindrome = true;
        for (unsigned j = 0, k = K - 1; j < k; ++j, --k){
            // 上から見たときと下から見たときで違うビットがあれば
            if ((1 & i >> j) != (1 & i >> k)){
                // 回文ではない
                is_palindrome = false;
                break;
            }
        }
        // 回文でなければ追加
        if (!is_palindrome)
            not_palindrome.emplace_back(i);
    }

    string S;
    cin >> S;

    vector<modint> dp(1U << K - 1), prev(1U << K - 1);
    { // 先頭 K - 1 文字としてありえるものをすべて列挙する
        // a_mask = 'A' の場所だけ 1 になっているビットマスク
        // q_mask = '?' の場所だけ 0 になっているビットマスク
        unsigned a_mask{}, q_mask{};
        for (const auto c : S | views::take(K - 1)) {
            (a_mask *= 2) += c == 'A';
            (q_mask *= 2) += c != '?';
        }

        // q_mask のビットは常に確定させて、'?' の部分を全探索
        for (unsigned i{q_mask}; i < 1U << K - 1; ++i |= q_mask)
            dp[i ^ a_mask] = 1;
    }

    const unsigned mask{(1U << K - 1) - 1};
    for (const auto c : S | views::drop(K - 1)) {
        swap(dp, prev);
        ranges::fill(dp, modint{});

        // 'A' を追加する場合
        if (c != 'B')
            // 回文でなく、末尾が 'A'(0) であるような場合について遷移する
            for(const auto i : not_palindrome | views::filter([](auto i){return ~i & 1;}))
                dp[i & mask] += prev[i / 2];

        // 'B' を追加する場合
        if (c != 'A')
            // 回文でなく、末尾が 'B'(1) であるような場合について遷移する
            for(const auto i : not_palindrome | views::filter([](auto i){return i & 1;}))
                dp[i & mask] += prev[i / 2];
    }

    // すべての接尾辞に対する合計が答え
    cout << reduce(begin(dp), end(dp)).val() << endl;
    return 0;
}
N, K = map(int, input().split())

S = input()

# はじめ、A でも B でもない文字で埋めておく
mp = {'C' * (K - 1) : 1}

for c in S:
    # 末尾に 1 文字追加する
    # 'A' を追加する場合と 'B' を追加する場合を合併
    # 追加できない場合は空の dictionary で合併
    tmp = ({s + 'A' : v for s, v in mp.items()} if c != 'B' else {}) | ({s + 'B' : v for s, v in mp.items()} if c != 'A' else {})
    
    mp = {} # DP テーブルを消去
    
    for s, v in tmp.items():
        if s != s[::-1]: # 回文でない場合、先頭を削って追加する
            if s[1:] in mp:
                mp[s[1:]] += v
            else:
                mp[s[1:]] = v

# 合計を 998244353 で割ったあまりが答え
print(sum(mp.values()) % 998244353)

posted:
last update: