
F - Manhattan Cafe Editorial by en_translator


Consider a DP (Dynamic Programming) that determines the components of \(r\) one by one, starting from the first.
The DP needs two parameters: the distance between \(p\) and \(r\), and the distance between \(q\) and \(r\), both taken over the components determined so far. Define the states indexed by these two values as follows:

  • \(dp_{t, i, j} :=\) (the number of ways to determine the first \(t\) components of \(r\) such that \(\sum_{n=1}^t |p_n - r_n| = i, \sum_{n=1}^t |q_n - r_n| = j\)).

Then, the answer to the problem is expressed by:

\[\sum_{i=0}^D \sum_{j=0}^D dp_{N, i, j}.\]
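Summing \(dp_{N, i, j}\) over \(0 \le i, j \le D\) counts the integer vectors \(r\) whose distances to \(p\) and to \(q\) are both at most \(D\). For tiny inputs the same count can be obtained by direct enumeration, which is handy as a reference when testing a DP implementation. The following is a minimal sketch; the function name brute_force is hypothetical, it enumerates \((2D+1)^N\) candidates, and the modulus is omitted because the counts stay small.

```python
from itertools import product


def brute_force(p, q, D):
    """Count integer vectors r with sum|p_n - r_n| <= D and sum|q_n - r_n| <= D.

    Only feasible for tiny N and D: it enumerates (2D + 1)^N candidates.
    """
    # Any valid r satisfies |p_n - r_n| <= D, so r_n lies in [p_n - D, p_n + D].
    ranges = [range(pn - D, pn + D + 1) for pn in p]
    count = 0
    for r in product(*ranges):
        dist_p = sum(abs(a - b) for a, b in zip(p, r))
        dist_q = sum(abs(a - b) for a, b in zip(q, r))
        if dist_p <= D and dist_q <= D:
            count += 1
    return count


print(brute_force([0, 0], [1, 1], 2))  # → 8
```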

The initial state and the transitions of the DP are described by the following C++-ish code (using modint998244353 from the AtCoder Library):

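#include <bits/stdc++.h>
#include <atcoder/modint>
using namespace std;
using ll = long long;
#define rep(i, n) for (int i = 0; i < (int)(n); i++)

constexpr int MAX = 1001;  // assumes D < MAX
atcoder::modint998244353 dp[MAX][MAX], nx[MAX][MAX];

int main() {
  int N, D;
  cin >> N >> D;
  vector<ll> p(N), q(N);
  rep(t, N) cin >> p[t];
  rep(t, N) cin >> q[t];
  dp[0][0] = 1;
  rep(t, N) {
    rep(i, D + 1) rep(j, D + 1) nx[i][j] = 0;
    ll pt = p[t], qt = q[t];
    // assumes every useful value of r_t lies in [-2 * MAX, 2 * MAX]
    for (int rt = -MAX * 2; rt <= MAX * 2; rt++) {
      int di = (int)abs(pt - rt);
      int dj = (int)abs(qt - rt);
      if (max(di, dj) > D) continue;  // this component alone already exceeds D
      for (int i = 0, ie = D - di; i <= ie; i++) {
        for (int j = 0, je = D - dj; j <= je; j++) {
          nx[i + di][j + dj] += dp[i][j];
        }
      }
    }
    swap(dp, nx);
  }
  // the answer: sum of dp[i][j] over 0 <= i, j <= D
  atcoder::modint998244353 ans = 0;
  rep(i, D + 1) rep(j, D + 1) ans += dp[i][j];
  cout << ans.val() << endl;
}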

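The problem has been solved… but wait, really? The time complexity matters! For each \(t\), there are \(\mathrm{O}(D)\) values of \(r_t\) with both \(|p_t - r_t| \le D\) and \(|q_t - r_t| \le D\), and each of them costs an \(\mathrm{O}(D^2)\) double loop, so the DP above takes \(\mathrm{O}(ND^3)\) time in total and results in TLE (Time Limit Exceeded).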

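Instead, look at the structure of the DP transition. Let \(s = \vert p_t - q_t \vert\). Depending on whether \(r_t\) lies between \(p_t\) and \(q_t\) or outside that interval on one side or the other, the pairs \((di, dj)\) in the code above fall into the following three groups: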

  • \((s,0), (s-1, 1), (s-2, 2), \dots, (1, s-1), (0, s)\)
  • \((s+1, 1), (s+2, 2), (s+3, 3), \dots\)
  • \((1, s+1), (2, s+2), (3, s+3), \dots\)
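The classification can be checked mechanically. The sketch below (both function names are hypothetical) collects the pairs \((|p_t - r_t|, |q_t - r_t|)\) for every integer \(r_t\) in a window extending \(D\) beyond \(p_t\) and \(q_t\), and compares the resulting multiset with the three groups, taking \(k = 1, \dots, D\) in the two outer groups. Multisets are compared because when \(s = 0\) the second and third groups coincide as sets: \(r_t = p_t + k\) and \(r_t = p_t - k\) both yield \((k, k)\).

```python
from collections import Counter


def pairs_from_rt(pt, qt, D):
    """Multiset of (|pt - rt|, |qt - rt|) for rt in [min(pt, qt) - D, max(pt, qt) + D]."""
    lo, hi = min(pt, qt), max(pt, qt)
    return Counter((abs(pt - rt), abs(qt - rt)) for rt in range(lo - D, hi + D + 1))


def pairs_from_groups(pt, qt, D):
    """The same multiset, built from the three groups with s = |pt - qt| and 1 <= k <= D."""
    s = abs(pt - qt)
    c = Counter((s - k, k) for k in range(s + 1))  # r_t between p_t and q_t
    c.update((s + k, k) for k in range(1, D + 1))   # r_t beyond q_t, seen from p_t
    c.update((k, s + k) for k in range(1, D + 1))   # r_t beyond p_t, seen from q_t
    return c


for pt, qt, D in [(0, 0, 3), (0, 2, 3), (5, 1, 4), (-2, 3, 2)]:
    assert pairs_from_rt(pt, qt, D) == pairs_from_groups(pt, qt, D)
print("classification holds")
```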

When the elements of each group are plotted on a grid, they lie on a single line: the first group on an anti-diagonal segment (\(di + dj = s\)), and the second and third groups on diagonal half-lines (\(di - dj = s\) and \(dj - di = s\), respectively). Therefore, by precomputing diagonal cumulative sums of the DP table, the transition into each \((i, j)\) takes \(\mathrm{O}(1)\) time instead of \(\mathrm{O}(D)\), reducing the total time complexity from \(\mathrm{O}(ND^3)\) to \(\mathrm{O}(ND^2)\). (For more details, please refer to the sample code.) This is fast enough.
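For reference, the transition for a fixed \(t\) can be written out as follows, with \(dp\) standing for \(dp_t\), \(nx\) for \(dp_{t+1}\), and every entry whose index falls outside \([0, D]\) treated as \(0\). The three sums correspond to the three groups, in the order listed:

\[nx_{i,j} = \sum_{k=0}^{s} dp_{i-k,\,j-s+k} + \sum_{k \ge 1} dp_{i-s-k,\,j-k} + \sum_{k \ge 1} dp_{i-k,\,j-s-k}.\]

With the cumulative sums \(A_{i,j} = \sum_{k \ge 0} dp_{i-k,\,j+k}\) along anti-diagonals and \(B_{i,j} = \sum_{k \ge 0} dp_{i-k,\,j-k}\) along diagonals (the names \(A\) and \(B\) are introduced here for illustration; they play the roles of dp2 and dp3 in the sample code), each sum reduces to a constant number of lookups:

\[\sum_{k=0}^{s} dp_{i-k,\,j-s+k} = A_{i,\,j-s} - A_{i-s-1,\,j+1}, \qquad \sum_{k \ge 1} dp_{i-s-k,\,j-k} = B_{i-s-1,\,j-1}, \qquad \sum_{k \ge 1} dp_{i-k,\,j-s-k} = B_{i-1,\,j-s-1}.\]

When \(j - s < 0\), the leading terms of \(A_{i,\,j-s}\) vanish, so it equals \(A_{i+j-s,\,0}\); this is the adjustment of (si, sj) in the sample code.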

  • Sample code (PyPy)
mod = 998244353
N, D = map(int, input().split())
p = list(map(int, input().split()))
q = list(map(int, input().split()))

# dp[i][j]: number of ways so far with distance i to p and distance j to q
dp = [[0 for _ in range(D + 1)] for _ in range(D + 1)]
dp[0][0] = 1

for n in range(N):
  pn, qn = p[n], q[n]
  s = abs(pn - qn)
  nxt = [[0 for _ in range(D + 1)] for _ in range(D + 1)]

  # dp2[i][j] = dp[i][j] + dp[i-1][j+1] + dp[i-2][j+2] + ... (anti-diagonal cumulative sums)
  dp2 = [[0 for _ in range(D + 1)] for _ in range(D + 1)]
  for i in range(D + 1):
    for j in range(D + 1):
      dp2[i][j] = dp[i][j]
      if i != 0 and j != D:
        dp2[i][j] += dp2[i - 1][j + 1]
        dp2[i][j] %= mod

  # first group: nxt[i][j] += dp[i][j-s] + dp[i-1][j-s+1] + ... + dp[i-s][j]
  for i in range(D + 1):
    for j in range(D + 1):
      si = i
      sj = j - s
      if sj < 0:
        # entries with a negative column are zero, so start from column 0
        si += sj
        sj = 0
      if 0 <= si <= D and 0 <= sj <= D:
        nxt[i][j] += dp2[si][sj]
        nxt[i][j] %= mod
      # subtract the part of the cumulative sum that lies beyond dp[i-s][j]
      ti = i - (s + 1)
      tj = j + 1
      if 0 <= ti <= D and 0 <= tj <= D:
        nxt[i][j] -= dp2[ti][tj]
        nxt[i][j] %= mod

  # dp3[i][j] = dp[i][j] + dp[i-1][j-1] + dp[i-2][j-2] + ... (diagonal cumulative sums)
  dp3 = [[0 for _ in range(D + 1)] for _ in range(D + 1)]
  for i in range(D + 1):
    for j in range(D + 1):
      dp3[i][j] = dp[i][j]
      if i != 0 and j != 0:
        dp3[i][j] += dp3[i - 1][j - 1]
        dp3[i][j] %= mod
      # third group (k, s + k), k >= 1
      if i + 1 <= D and j + s + 1 <= D:
        nxt[i + 1][j + s + 1] += dp3[i][j]
        nxt[i + 1][j + s + 1] %= mod
      # second group (s + k, k), k >= 1
      if i + s + 1 <= D and j + 1 <= D:
        nxt[i + s + 1][j + 1] += dp3[i][j]
        nxt[i + s + 1][j + 1] %= mod

  dp = nxt

print(sum((sum(v) for v in dp)) % mod)
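To sanity-check the \(\mathrm{O}(ND^2)\) transition, here is a sketch that wraps the same idea into a function and compares it with direct enumeration on small random inputs. The names count_fast and brute_force are hypothetical; A and B play the roles of dp2 and dp3 in the sample code, and unlike the sample this version takes its input as arguments instead of reading standard input.

```python
import random
from itertools import product

MOD = 998244353


def count_fast(p, q, D):
    """O(N D^2) DP with diagonal cumulative sums, following the sample code."""
    W = D + 1
    dp = [[0] * W for _ in range(W)]
    dp[0][0] = 1
    for pn, qn in zip(p, q):
        s = abs(pn - qn)
        # A[i][j] = dp[i][j] + dp[i-1][j+1] + ...  (anti-diagonal, like dp2)
        # B[i][j] = dp[i][j] + dp[i-1][j-1] + ...  (diagonal, like dp3)
        A = [[0] * W for _ in range(W)]
        B = [[0] * W for _ in range(W)]
        for i in range(W):
            for j in range(W):
                A[i][j] = (dp[i][j] + (A[i - 1][j + 1] if i > 0 and j < D else 0)) % MOD
                B[i][j] = (dp[i][j] + (B[i - 1][j - 1] if i > 0 and j > 0 else 0)) % MOD
        nxt = [[0] * W for _ in range(W)]
        for i in range(W):
            for j in range(W):
                v = 0
                # (di, dj) = (k, s - k), 0 <= k <= s: dp[i][j-s] + ... + dp[i-s][j]
                si, sj = i, j - s
                if sj < 0:  # entries left of column 0 are zero; start at column 0
                    si, sj = si + sj, 0
                if si >= 0:
                    v += A[si][sj]
                if i - s - 1 >= 0 and j + 1 <= D:
                    v -= A[i - s - 1][j + 1]
                # (di, dj) = (s + k, k), k >= 1
                if i - s - 1 >= 0 and j - 1 >= 0:
                    v += B[i - s - 1][j - 1]
                # (di, dj) = (k, s + k), k >= 1
                if i - 1 >= 0 and j - s - 1 >= 0:
                    v += B[i - 1][j - s - 1]
                nxt[i][j] = v % MOD
        dp = nxt
    return sum(map(sum, dp)) % MOD


def brute_force(p, q, D):
    """Direct enumeration of r; only for tiny N and D."""
    ranges = [range(pn - D, pn + D + 1) for pn in p]
    return sum(
        1
        for r in product(*ranges)
        if sum(abs(a - b) for a, b in zip(p, r)) <= D
        and sum(abs(a - b) for a, b in zip(q, r)) <= D
    )


# Cross-check on small random inputs.
rng = random.Random(0)
for _ in range(100):
    n = rng.randint(1, 3)
    D = rng.randint(0, 3)
    p = [rng.randint(-3, 3) for _ in range(n)]
    q = [rng.randint(-3, 3) for _ in range(n)]
    assert count_fast(p, q, D) == brute_force(p, q, D) % MOD, (p, q, D)
print("all checks passed")
```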

The sample code above favors readability, so its constant factor is poor. If you want to make it faster, you can replace

dp = [[0 for _ in range(D + 1)] for _ in range(D + 1)]

with

dp = [0] * (D + 1) * (D + 1)
# dp[i * (D + 1) + j] refers to the (i, j)-th component

and rewrite every access dp[i][j] as dp[i * (D + 1) + j] (likewise for nxt, dp2, and dp3). This makes the code roughly two to three times faster, so we recommend implementing it this way.
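As an illustration of the flat layout, here is the diagonal cumulative sum (dp3 in the sample code) rewritten for a one-dimensional list. It is a sketch; the function name diagonal_prefix_flat is hypothetical.

```python
MOD = 998244353


def diagonal_prefix_flat(dp, D):
    """Flat-list version of dp3 in the sample code.

    dp and the result both store the (i, j) entry at index i * (D + 1) + j, and
    result[i * (D + 1) + j] = dp[i][j] + dp[i-1][j-1] + dp[i-2][j-2] + ... (mod MOD).
    """
    W = D + 1
    B = [x % MOD for x in dp]  # row 0 and column 0 are plain copies of dp
    for i in range(1, W):
        row, prev = i * W, (i - 1) * W
        for j in range(1, W):
            B[row + j] = (B[row + j] + B[prev + j - 1]) % MOD
    return B


D = 2
dp = list(range((D + 1) * (D + 1)))  # dp[i][j] = 3 * i + j, stored flat
print(diagonal_prefix_flat(dp, D))  # → [0, 1, 2, 3, 4, 6, 6, 10, 12]
```

Only the index arithmetic changes; computing row = i * W once per row keeps the multiplication out of the inner loop.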

  • The time limit is tight for slow languages. This is because a naive C++ solution running in \(\mathrm{O}(\frac{N D^3}{w})\) time runs in around 10 seconds even with a rough implementation, and we wanted such solutions to get TLE (Time Limit Exceeded). (Here, \(w\) denotes the number of elements processed simultaneously by automatic vectorization, enabled for example via #pragma directives; typically \(w = 8\) or \(16\).)
    • Careful readers may wonder, “Couldn't we just make \(D\) larger?”, since the ratio of the execution times of the two solutions is about \(D/w\). However, because the intended solution costs \(\mathrm{O}(ND^2)\), a larger \(D\) would require a smaller \(N\), which would in turn make an \(\mathrm{O}(N^2 D^2)\) solution feasible.
