F - All Included Editorial by en_translator
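We solve this problem with dynamic programming (DP).
Define a DP table by \(\mathrm{dp}[i][s][last] =\) the number of ways to choose the first \(i\) characters so that the strings \(S_j\) already occurring as substrings are exactly those with \(j\in s\), and the longest suffix of the string so far that is a prefix of some \(S_k\) equals \(last\). Here \(s\) is a subset of \(\{1,2,\ldots,N\}\). Only a suffix that can still grow into some \(S_k\) matters for future matches, so we may always take \(last\) to be a prefix of some \(S_k\).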
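In symbols, write \(\delta(t, c)\) for the longest suffix of \(tc\) that is a prefix of some \(S_k\), and \(\mathrm{out}(t)\) for the set of \(j\) such that \(S_j\) is a suffix of \(t\). Both names are introduced here only for illustration. Under these definitions, a sketch of the recurrence (all values taken modulo 998244353, as in the code) is:

\[
\begin{aligned}
\mathrm{dp}[0][\emptyset][\varepsilon] &= 1,\\
\mathrm{dp}[i+1]\bigl[s \cup \mathrm{out}(\delta(last,c))\bigr]\bigl[\delta(last,c)\bigr] &\mathrel{+}= \mathrm{dp}[i][s][last] \qquad (c \in \{\texttt{a},\ldots,\texttt{z}\}),\\
\text{answer} &= \sum_{last} \mathrm{dp}[L][\{1,\ldots,N\}][last],
\end{aligned}
\]

where \(\varepsilon\) is the empty string. Checking only suffixes of \(\delta(last,c)\) suffices: any \(S_j\) that ends at the new position is a prefix of itself, so it is a suffix of \(\delta(last,c)\).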
For example, if \(S\) is abcdef, bcdgh, dij, dik, and the string constructed so far is ...abcd (so \(last\) is abcd), then the following transitions are possible, where the suffix shown is the new \(last\):
- ...abcde: if the next character is e
- ...a: if the next character is a
- ...bcdg: if the next character is g
- ...b: if the next character is b
- ...di: if the next character is i
- ...d: if the next character is d
- ... (\(last\) becomes the empty string): otherwise.
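The transitions in this example can be reproduced with a naive helper. The function `next_last` below is a hypothetical illustration, not part of the editorial's code: it tries every suffix of \(last\) followed by the new character, from longest to shortest, and returns the first one that is a prefix of some pattern. This is far slower than Aho-Corasick but easy to check by hand.

```python
def next_last(last, c, patterns):
    """Longest suffix of last + c that is a prefix of some pattern (naive)."""
    t = last + c
    for start in range(len(t) + 1):
        cand = t[start:]
        # The empty suffix (start == len(t)) is a prefix of every pattern.
        if any(p.startswith(cand) for p in patterns):
            return cand
    return ""  # only reached when `patterns` is empty


patterns = ["abcdef", "bcdgh", "dij", "dik"]
for c in "eagbidx":
    print(c, repr(next_last("abcd", c, patterns)))
```

Running the loop reproduces the list above, with any character not listed (here x) leading to the empty string.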
The number of states of the DP is \(O(2^NLM)\), where \(M=\sum_i |S_i|\), because \(last\) ranges over at most \(M+1\) distinct prefixes. Each state has \(\sigma=26\) possible transitions, so the overall time complexity is \(O(2^NLM\sigma)\).
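The new \(last\) after appending a character can be computed with the Aho-Corasick algorithm: each node of its trie corresponds to a prefix of some \(S_i\), the completed goto function gives the transition, and each node also stores the bitmask of the \(S_j\) that end there (including those reached through failure links). A sample implementation in Python follows.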
from collections import deque


class AhoCorasick:
    def __init__(self, sigma=26):
        self.node = [[-1] * sigma]  # node[v][c]: child of v by c (full goto function after build)
        self.last = [0]  # last[v]: bitmask of patterns that are suffixes of node v's string
        self.sigma = sigma

    def add(self, arr, ID):
        # Insert the pattern arr (a list of character codes) into the trie.
        v = 0
        for c in arr:
            if self.node[v][c] == -1:
                self.node[v][c] = len(self.node)
                self.node.append([-1] * self.sigma)
                self.last.append(0)
            v = self.node[v][c]
        self.last[v] |= 1 << ID

    def build(self):
        # BFS over the trie: compute failure links and fill in missing transitions.
        link = [0] * len(self.node)
        que = deque()
        for i in range(self.sigma):
            if self.node[0][i] == -1:
                self.node[0][i] = 0
            else:
                link[self.node[0][i]] = 0
                que.append(self.node[0][i])
        while que:
            v = que.popleft()
            # link[v] is shallower, so its mask is already final.
            self.last[v] |= self.last[link[v]]
            for i in range(self.sigma):
                u = self.node[v][i]
                if u == -1:
                    self.node[v][i] = self.node[link[v]][i]
                else:
                    link[u] = self.node[link[v]][i]
                    que.append(u)


mod = 998244353
N, L = map(int, input().split())
AC = AhoCorasick()
for i in range(N):
    AC.add([ord(c) - ord("a") for c in input()], i)
AC.build()
m = len(AC.node)
# dp[bit][v]: number of strings so far whose contained-pattern set is bit and whose state is node v
dp = [[0] * m for i in range(1 << N)]
dp[0][0] = 1
for _ in range(L):
    ndp = [[0] * m for i in range(1 << N)]
    for bit in range(1 << N):
        for v in range(m):
            if dp[bit][v] == 0:
                continue  # contributes nothing; skipping only saves time
            for i in range(26):
                to = AC.node[v][i]
                nbit = bit | AC.last[to]
                ndp[nbit][to] = (dp[bit][v] + ndp[nbit][to]) % mod
    dp = ndp
print(sum(dp[-1]) % mod)
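To sanity-check the DP on small inputs, the following sketch replaces the Aho-Corasick automaton with naive suffix matching and compares the result against brute-force enumeration. The helper names `step`, `count_dp`, and `count_brute` are hypothetical and not part of the editorial; the naive transition is only practical for tiny patterns and alphabets.

```python
from itertools import product

MOD = 998244353


def step(last, c, patterns):
    """Return (nxt, mask) after appending c to a string whose state is last.

    nxt: longest suffix of last + c that is a prefix of some pattern.
    mask: bitmask of the patterns that end at the new position.
    """
    t = last + c
    nxt = ""
    for start in range(len(t) + 1):
        if any(p.startswith(t[start:]) for p in patterns):
            nxt = t[start:]
            break
    # Any pattern ending here is a prefix of itself, hence a suffix of nxt.
    mask = 0
    for j, p in enumerate(patterns):
        if nxt.endswith(p):
            mask |= 1 << j
    return nxt, mask


def count_dp(patterns, L, alphabet):
    """dp[(s, last)] = number of strings of the current length in that state."""
    full = (1 << len(patterns)) - 1
    dp = {(0, ""): 1}
    for _ in range(L):
        ndp = {}
        for (s, last), ways in dp.items():
            for c in alphabet:
                nxt, mask = step(last, c, patterns)
                key = (s | mask, nxt)
                ndp[key] = (ndp.get(key, 0) + ways) % MOD
        dp = ndp
    return sum(w for (s, _), w in dp.items() if s == full) % MOD


def count_brute(patterns, L, alphabet):
    """Enumerate all |alphabet|**L strings (only feasible for tiny inputs)."""
    return sum(
        all(p in "".join(t) for p in patterns)
        for t in product(alphabet, repeat=L)
    ) % MOD
```

For instance, `count_dp(["abc"], 3, "abc")` and `count_brute(["abc"], 3, "abc")` both count the single string abc. Using dictionary keys keyed by the prefix string keeps the sketch short; the editorial's solution instead indexes states by trie node, which is what makes it fast enough for the real constraints.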