Official

F - All Included Editorial by en_translator


We solve this problem with dynamic programming (DP).

Define a DP table by \(\mathrm{dp}[i][s][last] =\) the number of ways to fix the first \(i\) characters so that the strings contained as substrings so far are exactly \(S_j\ (j\in s)\), and the longest suffix that is a prefix of some \(S_j\) is \(last\). (\(s\) is a subset of \(\{1,2,\ldots,N\}\).) Keeping only this suffix is enough: whether the next character completes an occurrence of some \(S_j\) depends only on \(last\) and that character.
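To make the state definition concrete, here is a minimal sketch that runs the same DP with \(last\) stored as an explicit string instead of an automaton node. It is only meant for checking small cases by hand. The helper names `next_last` and `count_naive` are hypothetical, the alphabet is assumed to be lowercase a-z, and the answer is reduced modulo 998244353 as in the code.

```python
from collections import defaultdict
from string import ascii_lowercase

MOD = 998244353


def next_last(last, c, strings):
    """Longest suffix of last + c that is a prefix of some string in `strings`."""
    t = last + c
    for start in range(len(t) + 1):
        cand = t[start:]  # candidates are tried from longest to shortest
        if any(s.startswith(cand) for s in strings):
            return cand
    return ""


def count_naive(strings, L, alphabet=ascii_lowercase):
    n = len(strings)
    # dp[(s, last)] = number of length-i strings whose set of contained S_j is
    # exactly the bitmask s and whose longest suffix that is a prefix of some S_j is last
    dp = defaultdict(int)
    dp[(0, "")] = 1
    for _ in range(L):
        ndp = defaultdict(int)
        for (s, last), ways in dp.items():
            for c in alphabet:
                nl = next_last(last, c, strings)
                ns = s
                # every S_j that ends at the new character is a suffix of nl
                for j, sj in enumerate(strings):
                    if nl.endswith(sj):
                        ns |= 1 << j
                ndp[(ns, nl)] = (ndp[(ns, nl)] + ways) % MOD
        dp = ndp
    full = (1 << n) - 1
    return sum(w for (s, _), w in dp.items() if s == full) % MOD


print(count_naive(["a", "b"], 2))  # 2 ("ab" and "ba")
```

The dictionary only holds reachable states, so this is fine for tiny inputs, but the string operations make it far slower than the automaton-based program.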

For example, if the strings \(S_i\) are abcdef, bcdgh, dij, and dik, and the string constructed so far is ...abcd (so \(last=\) abcd), then the transitions are as follows:

  • ...abcde: if the next character is e
  • ...a: if the next character is a
  • ...bcdg: if the next character is g
  • ...b: if the next character is b
  • ...di: if the next character is i
  • ...d: if the next character is d
  • ... (\(last\) becomes the empty string): for any other character.
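The transitions in this example can be reproduced with a naive check that tries every suffix of \(last\) followed by the next character, from longest to shortest. This is a sketch for illustration only; `next_last` is a hypothetical helper, and `x` stands for "any other character".

```python
def next_last(last, c, strings):
    """Longest suffix of last + c that is a prefix of some S_i (naive check)."""
    t = last + c
    for start in range(len(t) + 1):
        if any(s.startswith(t[start:]) for s in strings):
            return t[start:]
    return ""


S = ["abcdef", "bcdgh", "dij", "dik"]
for c in "eagbidx":
    print(c, repr(next_last("abcd", c, S)))
# e 'abcde'
# a 'a'
# g 'bcdg'
# b 'b'
# i 'di'
# d 'd'
# x ''
```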

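The number of states of the DP is \(O(2^NLM)\), where \(M=\sum_i |S_i|\), because \(last\) is one of at most \(M+1\) distinct prefixes of the \(S_i\) (including the empty string). Each state has \(\sigma=26\) possible transitions, so the total time complexity is \(O(2^NLM\sigma)\).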
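The bound on the number of \(last\) values can be checked on the example above: prefixes shared between strings are counted only once, so the count can be strictly below \(M+1\). A small sketch, where `prefix_states` is a hypothetical helper:

```python
def prefix_states(strings):
    """All distinct prefixes (including the empty one) of the given strings."""
    return {s[:k] for s in strings for k in range(len(s) + 1)}


S = ["abcdef", "bcdgh", "dij", "dik"]
M = sum(len(s) for s in S)
P = prefix_states(S)
print(len(P), M + 1)  # 16 18: shared prefixes such as "d" and "di" are counted once
# transitions processed per character position: 2^N * (number of last values) * 26
print((1 << len(S)) * len(P) * 26)  # 6656
```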

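The next value of \(last\) can be computed with an algorithm such as Aho-Corasick: after building the automaton of the \(S_i\) once, each transition is a table lookup, and each automaton node also records which \(S_j\) end at the new character.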

from collections import deque


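class AhoCorasick:
    def __init__(self, sigma=26):
        # node[v][c]: child of trie node v by character c (-1 if absent);
        # after build(), the complete transition table of the automaton
        self.node = [[-1] * sigma]
        # last[v]: bitmask of patterns ending at node v; after build(), of all
        # patterns that are suffixes of the string represented by v
        self.last = [0]
        self.sigma = sigma

    def add(self, arr, ID):
        # insert pattern arr (a list of character codes) and mark its end with bit ID
        v = 0
        for c in arr:
            if self.node[v][c] == -1:
                self.node[v][c] = len(self.node)
                self.node.append([-1] * self.sigma)
                self.last.append(0)
            v = self.node[v][c]
        self.last[v] |= 1 << ID

    def build(self):
        # link[v]: failure link, i.e. the node of the longest proper suffix of v
        link = [0] * len(self.node)
        que = deque()
        # missing edges from the root loop back to the root
        for i in range(self.sigma):
            if self.node[0][i] == -1:
                self.node[0][i] = 0
            else:
                link[self.node[0][i]] = 0
                que.append(self.node[0][i])

        # BFS by depth, so link[v] is fully processed before v
        while que:
            v = que.popleft()
            # patterns that end at link[v] also end at v
            self.last[v] |= self.last[link[v]]
            for i in range(self.sigma):
                u = self.node[v][i]
                if u == -1:
                    # no trie edge: reuse the transition of the failure link
                    self.node[v][i] = self.node[link[v]][i]
                else:
                    link[u] = self.node[link[v]][i]
                    que.append(u)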


mod = 998244353


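N, L = map(int, input().split())
AC = AhoCorasick()
for i in range(N):
    AC.add([ord(c) - ord("a") for c in input().strip()], i)

AC.build()
m = len(AC.node)  # number of automaton nodes, i.e. possible values of last
# dp[bit][v]: number of strings built so far that contain exactly the patterns
# in bit and whose automaton node is v
dp = [[0] * m for _ in range(1 << N)]
dp[0][0] = 1
for _ in range(L):
    ndp = [[0] * m for _ in range(1 << N)]
    for bit in range(1 << N):
        for v in range(m):
            cur = dp[bit][v]
            if cur == 0:
                continue  # unreachable state; skipping it saves time
            for i in range(26):
                to = AC.node[v][i]
                nbit = bit | AC.last[to]
                ndp[nbit][to] = (ndp[nbit][to] + cur) % mod
    dp = ndp
# dp[-1] is bit = 2^N - 1, i.e. every S_i is contained
print(sum(dp[-1]) % mod)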
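For small inputs, the DP can be cross-checked against exhaustive enumeration. A minimal sketch: `count_brute` is a hypothetical helper that tries all \(26^L\) strings, so it is only practical for \(L\le 3\) or so.

```python
from itertools import product
from string import ascii_lowercase


def count_brute(strings, L, alphabet=ascii_lowercase):
    """Count length-L strings over `alphabet` that contain every S_i as a substring."""
    total = 0
    for chars in product(alphabet, repeat=L):
        t = "".join(chars)
        if all(s in t for s in strings):
            total += 1
    return total


print(count_brute(["a", "b"], 2))  # 2 ("ab" and "ba")
print(count_brute(["aa"], 3))      # 51
```

On the input `2 2` followed by the lines `a` and `b`, the DP program should print the same value, 2.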
