Official

F - All Included Editorial by toam


動的計画法で解きます.

\(\mathrm{dp}[i][s][last]\) で「文字列の \(i\) 文字目まで決めたとき,すでに \(S_j\ (j\in s)\) を含んでいて,文字列の末尾が \(last\) であるようなもの」という dp を考えます(\(s\)\(\{1,2,\ldots,N\}\) の部分集合).このとき,\(last\) はいずれかの \(S_i\) の prefix であるようなものしか考えなくて十分です.

例えば,\(S\)abcdef, bcdgh, dij, dik で,今作っている文字列の末尾が ...abcd だったとき,遷移先としては

  • ...abcde: 次の文字が e のとき
  • ...a: 次の文字が a のとき
  • ...bcdg: 次の文字が g のとき
  • ...b: 次の文字が b のとき
  • ...di: 次の文字が i のとき
  • ...d: 次の文字が d のとき
  • ...: 上のいずれでもない

のようになります.

dp の状態数は \(M=\sum_i |S_i|\) として \(O(2^NLM)\) です.各状態に対して,遷移は \(\sigma=26\) 通りあるので,計算量は \(O(2^NLM\sigma)\) になります.

dp の遷移先は Aho–Corasick 法などで求められます.

from collections import deque


class AhoCorasick:
    def __init__(self, sigma=26):
        self.node = [[-1] * sigma]
        self.last = [0]
        self.sigma = sigma

    def add(self, arr, ID):
        v = 0
        for c in arr:
            if self.node[v][c] == -1:
                self.node[v][c] = len(self.node)
                self.node.append([-1] * self.sigma)
                self.last.append(0)
            v = self.node[v][c]
        self.last[v] |= 1 << ID

    def build(self):
        link = [0] * len(self.node)
        que = deque()
        for i in range(self.sigma):
            if self.node[0][i] == -1:
                self.node[0][i] = 0
            else:
                link[self.node[0][i]] = 0
                que.append(self.node[0][i])

        while que:
            v = que.popleft()
            self.last[v] |= self.last[link[v]]
            for i in range(self.sigma):
                u = self.node[v][i]
                if u == -1:
                    self.node[v][i] = self.node[link[v]][i]
                else:
                    link[u] = self.node[link[v]][i]
                    que.append(u)


mod = 998244353


N, L = map(int, input().split())
AC = AhoCorasick()
for i in range(N):
    AC.add([ord(c) - ord("a") for c in input()], i)

AC.build()
m = len(AC.node)
dp = [[0] * m for i in range(1 << N)]
dp[0][0] = 1
for _ in range(L):
    ndp = [[0] * m for i in range(1 << N)]
    for bit in range(1 << N):
        for v in range(m):
            for i in range(26):
                to = AC.node[v][i]
                nbit = bit | AC.last[to]
                ndp[nbit][to] = (dp[bit][v] + ndp[nbit][to]) % mod
    dp = ndp
print(sum(dp[-1]) % mod)

posted:
last update: