D - Delete Range Mex Editorial by evima (official)


First, let us determine when it is possible to transform \(B = P\) into \(B = A\). Since there are no duplicate elements in \(B\), once an integer \(x\) has been removed from \(B\), no interval of \(B\) contains \(x\), so every interval has \(\mathrm{mex}\) at most \(x\); and because \(x\) itself is gone, only integers smaller than \(x\) can be removed afterwards. Therefore, the elements not included in \(A\) must be removed in descending order.
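To see this claim on a small case, the sketch below lists every value that a single operation could delete. It assumes the operation picks a non-empty contiguous interval of \(B\) and deletes the element equal to that interval's \(\mathrm{mex}\), which is how the argument above reads; `mex` and `removable_values` are hypothetical helper names, not part of the original solution.

```python
def mex(values):
    """Smallest non-negative integer not in values."""
    seen = set(values)
    m = 0
    while m in seen:
        m += 1
    return m


def removable_values(B):
    """Values one operation could delete: the mex of some non-empty
    contiguous interval of B, provided that value is still present in B."""
    n = len(B)
    mexes = {mex(B[l:r]) for l in range(n) for r in range(l + 1, n + 1)}
    return mexes & set(B)


# Start from [3, 0, 2, 1] and delete 2: every deletable value is now below 2.
print(removable_values([3, 0, 1]))  # prints {0, 1}
```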

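Additionally, to remove \(x\) we need an interval whose \(\mathrm{mex}\) is \(x\), that is, an interval that contains every integer from \(0\) to \(x-1\) but not \(x\). Since there are no duplicate elements in \(B\), such an interval lies entirely to the left or entirely to the right of \(x\), so it suffices to check the two maximal candidates (the part of \(B\) to the left of \(x\) and the part to the right of \(x\)) and choose the one that includes all elements less than \(x\). Because removals happen in descending order, every integer less than \(x\) is still in \(B\) at that moment, and removals never change the relative order of the remaining elements, so this can be checked directly on the positions in \(P\).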
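The removal strategy can be written as a small feasibility check for a given \(P\). This is a sketch under the same assumption about the operation, and `can_reach` is a hypothetical name. The slices `B[:p]` and `B[p + 1:]` are the two maximal intervals that avoid \(x\), and \(x\) is removable exactly when one of them has \(\mathrm{mex}\) equal to \(x\).

```python
def mex(values):
    """Smallest non-negative integer not in values."""
    seen = set(values)
    m = 0
    while m in seen:
        m += 1
    return m


def can_reach(P, A):
    """Greedy check: remove the values not in A in descending order.

    Assumes one operation deletes the mex of a chosen interval of B.
    """
    B = list(P)
    for x in sorted(set(P) - set(A), reverse=True):
        p = B.index(x)
        # The maximal intervals avoiding x are B[:p] and B[p + 1:].
        if mex(B[:p]) != x and mex(B[p + 1:]) != x:
            return False  # 0..x-1 are split across both sides of x
        B.pop(p)
    return B == list(A)


print(can_reach([2, 1, 0], [1, 0]), can_reach([1, 2, 0], [1, 0]))  # True False
```

With \(A = (1, 0)\), the permutation \((1, 2, 0)\) fails because \(0\) and \(1\) lie on different sides of \(2\).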

Hence, we need to count the permutations \(P\) that satisfy all of the following conditions:

  • \(A\) appears as a subsequence of \(P\).
  • For every integer \(x\) between \(0\) and \(N-1\) that is not included in \(A\), one of the following holds:
    • The indices of all integers from \(0\) to \(x-1\) are smaller than the index of \(x\).
    • The indices of all integers from \(0\) to \(x-1\) are larger than the index of \(x\).
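For very small \(N\), these conditions can be checked directly by enumerating permutations. This brute-force sketch is only meant for cross-checking (it takes \(O(N! \cdot N^2)\) time), and `satisfies` and `count_brute` are hypothetical names.

```python
from itertools import permutations


def satisfies(P, A):
    """Check both conditions for a permutation P of 0..N-1."""
    idx = {v: p for p, v in enumerate(P)}
    # Condition 1: A's (distinct) values appear in P in the same order.
    if any(idx[a] > idx[b] for a, b in zip(A, A[1:])):
        return False
    # Condition 2: each x not in A has all of 0..x-1 on one side of it.
    in_a = set(A)
    for x in range(1, len(P)):  # x = 0 holds vacuously
        if x in in_a:
            continue
        smaller = [idx[y] for y in range(x)]
        if not (max(smaller) < idx[x] or min(smaller) > idx[x]):
            return False
    return True


def count_brute(N, A):
    """Count permutations of 0..N-1 satisfying both conditions (tiny N only)."""
    return sum(satisfies(P, A) for P in permutations(range(N)))


print(count_brute(3, [0]), count_brute(3, [1, 0]))  # 4 2
```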

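We construct \(P\) by starting from \(A\) and inserting the integers not included in \(A\) in ascending order, choosing positions so that the conditions hold.

The first condition is then satisfied automatically.

For the second condition, when inserting an integer \(x\), all integers less than \(x\) are already in place, and later insertions do not change their order relative to \(x\). Hence \(x\) must be inserted either to the left of the leftmost integer less than \(x\) or to the right of the rightmost integer less than \(x\).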
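The insertion process can be enumerated the same way on tiny inputs. In this sketch (hypothetical name `count_by_insertion`), `slots` holds the insertion positions allowed by the rule above: every position up to the leftmost smaller integer and every position after the rightmost one. Distinct choices give distinct permutations, so counting the leaves of the recursion counts permutations, and the result should agree with counting permutations that satisfy the two conditions directly.

```python
def count_by_insertion(N, A):
    """Start from A and insert each value not in A, in ascending order, either
    left of the leftmost smaller value or right of the rightmost smaller value.
    Returns the number of distinct permutations produced."""
    in_a = set(A)
    missing = [x for x in range(N) if x not in in_a]

    def rec(seq, j):
        if j == len(missing):
            return 1
        x = missing[j]
        smaller = [p for p, v in enumerate(seq) if v < x]
        if smaller:
            # Positions 0..leftmost (left side) and rightmost+1..len(seq) (right side).
            slots = list(range(smaller[0] + 1)) + list(range(smaller[-1] + 1, len(seq) + 1))
        else:
            slots = list(range(len(seq) + 1))  # only x = 0: any gap of A
        return sum(rec(seq[:p] + [x] + seq[p:], j + 1) for p in slots)

    return rec(list(A), 0)


print(count_by_insertion(3, [0]), count_by_insertion(3, [1, 0]))  # 4 2
```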

Based on this, we use the following dynamic programming.

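Let \(dp[i][l][r]\) be the number of ways to insert the integers up to \(i\) that are not in \(A\) such that exactly \(l\) elements lie to the left of the leftmost integer at most \(i\), and exactly \(r\) elements lie to the right of the rightmost integer at most \(i\). All of these outside elements belong to \(A\), since every integer inserted so far is at most \(i\).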
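To make the state concrete, the sketch below computes \((l, r)\) for a partial arrangement, that is, \(A\) together with the inserted integers not in \(A\) that are at most \(i\). `lr_state` is a hypothetical helper, and the example uses \(N = 4\), \(A = (3, 1, 2)\), with \(0\) inserted right after \(3\).

```python
def lr_state(seq, i):
    """(l, r) for a partial arrangement seq (A plus inserted values <= i).
    Everything outside the span of values <= i is an element of A."""
    span = [p for p, v in enumerate(seq) if v <= i]
    return span[0], len(seq) - 1 - span[-1]


print(lr_state([3, 0, 1, 2], 0))  # (1, 2)
print(lr_state([3, 0, 1, 2], 1))  # (1, 1)
```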

The answer to the problem is \(dp[N-1][0][0]\): once every integer is placed, all of them are at most \(N-1\), so nothing lies outside.

If \(0\) is included in \(A\), let \(k\) be the (1-indexed) position with \(A_k = 0\) and initialize \(dp[0][k-1][M-k] = 1\).

If \(0\) is not included in \(A\), it can be placed in any of the \(M+1\) gaps of \(A\), so initialize \(dp[0][i][M-i] = 1\) for \(i = 0, 1, \dots, M\).
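The initialization can be sketched as follows, storing \(dp[i]\) as an \((M+1) \times (M+1)\) list indexed `[l][r]` in which only cells with \(l + r \le M\) are used. `init_dp` is a hypothetical name, and \(A\) is a 0-indexed Python list, so \(k\) is recovered as `A.index(0) + 1`.

```python
def init_dp(M, A):
    """dp[0] as an (M+1) x (M+1) table indexed [l][r]; only l + r <= M is used."""
    dp = [[0] * (M + 1) for _ in range(M + 1)]
    if 0 in A:
        k = A.index(0) + 1  # 1-indexed position with A_k = 0
        dp[k - 1][M - k] = 1
    else:
        for g in range(M + 1):  # 0 goes into gap g: g elements of A to its left
            dp[g][M - g] = 1
    return dp


print(init_dp(1, [1]))  # [[0, 1], [1, 0]]
```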

Compute \(dp[i]\) for \(i = 1, 2, \dots, N-1\) in order.

When \(i\) is included in \(A\)

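Let \(k\) be the position with \(A_k = i\). The new leftmost (rightmost) integer at most \(i\) is whichever of the previous one and \(A_k\) lies further left (right), so perform the following update for all \(l, r\):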

\[dp[i][\min(l,k-1)][\min(r,M-k)] += dp[i-1][l][r]\]

This update can be done in \(O(M^2)\) time.
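A direct, out-of-place version of this update is sketched below for clarity; the Python implementation further down performs the same update in place. `step_in_a` is a hypothetical name, \(k\) is 1-indexed as in the text, and the modulus \(998244353\) is taken from that implementation.

```python
MOD = 998244353


def step_in_a(dp, M, k):
    """dp[i-1] -> dp[i] when A_k = i (k is 1-indexed)."""
    new = [[0] * (M + 1) for _ in range(M + 1)]
    for l in range(M + 1):
        for r in range(M + 1 - l):
            a, b = min(l, k - 1), min(r, M - k)
            new[a][b] = (new[a][b] + dp[l][r]) % MOD
    return new


# N = 2, A = [1]: dp[0] = [[0, 1], [1, 0]], and i = 1 is A_1.
print(step_in_a([[0, 1], [1, 0]], 1, 1))  # [[2, 0], [0, 0]]
```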

When \(i\) is not included in \(A\)

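Inserting \(i\) to the left of the leftmost integer at most \(i-1\) turns a state with \(L\) elements on the left into one with any \(l \le L\) elements on the left (one choice for each of the \(L+1\) gaps) while leaving the right side unchanged; inserting it to the right is symmetric. Hence: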

\[dp[i][l][r] = \sum_{L=l}^{M} dp[i-1][L][r] + \sum_{R=r}^{M} dp[i-1][l][R]\]

Computed naively, this takes \(O(M^3)\) time per \(i\), but with suffix sums over \(l\) and over \(r\) (cumulative sums taken from the end) it takes \(O(M^2)\) time.
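The suffix-sum computation can be sketched as follows, next to a direct evaluation of the formula for comparison. Both function names are hypothetical; `suf_l` and `suf_r` get one extra row and column of zeros so that the recurrences need no boundary checks.

```python
MOD = 998244353


def step_not_in_a(dp, M):
    """dp[i-1] -> dp[i] when i is not in A, in O(M^2) via suffix sums."""
    # suf_l[l][r] = sum_{L >= l} dp[L][r],  suf_r[l][r] = sum_{R >= r} dp[l][R]
    suf_l = [[0] * (M + 2) for _ in range(M + 2)]
    suf_r = [[0] * (M + 2) for _ in range(M + 2)]
    for l in range(M, -1, -1):
        for r in range(M, -1, -1):
            suf_l[l][r] = (dp[l][r] + suf_l[l + 1][r]) % MOD
            suf_r[l][r] = (dp[l][r] + suf_r[l][r + 1]) % MOD
    return [[(suf_l[l][r] + suf_r[l][r]) % MOD for r in range(M + 1)]
            for l in range(M + 1)]


def step_not_in_a_naive(dp, M):
    """The same transition evaluated directly from the formula, O(M^3)."""
    return [[(sum(dp[L][r] for L in range(l, M + 1))
              + sum(dp[l][R] for R in range(r, M + 1))) % MOD
             for r in range(M + 1)] for l in range(M + 1)]


print(step_not_in_a([[0, 1], [1, 0]], 1))  # [[2, 2], [2, 0]]
```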

Therefore, the overall time complexity is \(O(NM^2)\).

The DP table can be initialized as described above. Alternatively, computing \(dp[\min(A)]\) directly with binomial coefficients narrows the range of the table that needs to be considered, which gives a constant-factor speedup.

Implementation examples with the constant-factor speedup:

  • C++ (34 ms)
  • PyPy (184 ms)

Implementation examples without the constant-factor speedup:

  • C++ (105 ms)
  • Python

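```python
MOD = 998244353

# Read input. B[v] is the 0-indexed position of v in A, or -1 if v is not in A.
N, M = map(int, input().split())
A = list(map(int, input().split()))
B = [-1] * N
for i in range(M):
    B[A[i]] = i

# dp[l][r] stores dp[i][l][r] for the current i (only cells with l + r <= M are used);
# n_dp is scratch space for the "i not in A" transition.
dp = [[0 for j in range(M + 1)] for i in range(M + 1)]
n_dp = [[0 for j in range(M + 1)] for i in range(M + 1)]
# 0 in A: dp[0][k-1][M-k] = 1 with k = B[0] + 1
if B[0] != -1:
    dp[B[0]][M - 1 - B[0]] = 1

# 0 not in A: 0 can be placed in any of the M + 1 gaps of A
else:
    for i in range(M + 1):
        dp[i][M - i] = 1

# transitions for i = 1, ..., N - 1
for i in range(1, N):

    # i in A (A_k = i with k = B[i] + 1): move dp[l][r] to dp[min(l, k-1)][min(r, M-k)].
    # Done in place: every target cell is a fixed point of this map, comes no later
    # than (l, r) in the loop order, and is never moved itself.
    if B[i] != -1:
        for l in range(M + 1):
            for r in range(M + 1 - l):
                a = min(l, B[i])
                b = min(r, M - 1 - B[i])
                if a != l or b != r:
                    dp[a][b] = (dp[a][b] + dp[l][r]) % MOD
                    dp[l][r] = 0

    # i not in A: dp[l][r] = sum_{L >= l} dp[L][r] + sum_{R >= r} dp[l][R]
    else:
        # copy dp into n_dp
        for l in range(M + 1):
            for r in range(M + 1 - l):
                n_dp[l][r] = dp[l][r]
        # suffix sums, computed in place:
        #   dp[l][r] = dp[l][r] + dp[l + 1][r] + dp[l + 2][r] + ...
        # n_dp[l][r] = dp[l][r] + dp[l][r + 1] + dp[l][r + 2] + ...
        for l in range(M, -1, -1):
            for r in range(M - l, -1, -1):
                if l != M:
                    dp[l][r] = (dp[l][r] + dp[l + 1][r]) % MOD
                if r != M:
                    n_dp[l][r] = (n_dp[l][r] + n_dp[l][r + 1]) % MOD

        # add the two suffix sums
        for l in range(M + 1):
            for r in range(M + 1 - l):
                dp[l][r] = (dp[l][r] + n_dp[l][r]) % MOD


# output dp[N-1][0][0]
print(dp[0][0])
```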

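As an end-to-end sanity check, the sketch below runs the DP (with the transition for \(i \notin A\) evaluated naively, for brevity) and compares it with a brute-force count of the permutations satisfying the two conditions, for every \(A\) with \(N \le 5\). The function names are hypothetical, and \(A\) is passed as a 0-indexed list.

```python
from itertools import permutations

MOD = 998244353


def count_dp(N, A):
    """The DP described above, with the O(M^3) transition written out plainly."""
    M = len(A)
    pos = {v: k for k, v in enumerate(A, start=1)}  # A_k = v, 1-indexed
    dp = [[0] * (M + 1) for _ in range(M + 1)]
    if 0 in pos:
        dp[pos[0] - 1][M - pos[0]] = 1
    else:
        for g in range(M + 1):
            dp[g][M - g] = 1
    for i in range(1, N):
        new = [[0] * (M + 1) for _ in range(M + 1)]
        for l in range(M + 1):
            for r in range(M + 1 - l):
                if i in pos:
                    k = pos[i]
                    a, b = min(l, k - 1), min(r, M - k)
                    new[a][b] = (new[a][b] + dp[l][r]) % MOD
                else:
                    new[l][r] = (sum(dp[L][r] for L in range(l, M + 1))
                                 + sum(dp[l][R] for R in range(r, M + 1))) % MOD
        dp = new
    return dp[0][0]


def count_brute(N, A):
    """Permutations of 0..N-1 meeting the two conditions derived earlier."""
    in_a = set(A)
    total = 0
    for P in permutations(range(N)):
        idx = {v: p for p, v in enumerate(P)}
        if any(idx[a] > idx[b] for a, b in zip(A, A[1:])):
            continue  # A is not a subsequence of P
        if all(x in in_a
               or max(idx[y] for y in range(x)) < idx[x]
               or min(idx[y] for y in range(x)) > idx[x]
               for x in range(1, N)):
            total += 1
    return total


for N in range(1, 6):
    for M in range(1, N + 1):
        for A in permutations(range(N), M):
            assert count_dp(N, list(A)) == count_brute(N, list(A)), (N, A)
print("DP matches brute force for N <= 5")
```

The brute force is exponential in \(N\), so it only serves to validate the DP on tiny inputs.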