E - RLE Editorial by Nyaan


(ユーザー解説です)

別解として、imos 法をしながら DP テーブルを更新していくアルゴリズムを紹介します。

まずは O(N3)\mathrm{O}(N^3) の解法を説明します。

  • dp[i][j]:=dp[i][j] := 変換後の文字列の長さが ii で、変換前の文字列の長さが jj であるような文字列の通り数

と DP テーブルを定義して、いわゆる「配る DP 」を行うことを考えます。すると、「同じ文字が djdj 個連なった文字列を末尾に加える」という操作は di:=di := 1 + (djdj の桁数) として

dp[i+di][j+dj]dp[i+di][j+dj]+dp[i][j](26 if i=0 else 25)dp[i+di][j+dj] \gets dp[i+di][j+dj] + dp[i][j] * (26 \text{ if i=0 else } 25)

になります。よって i,j,dji,j,dj の3 重ループをナイーブに実装すると O(N3)\mathrm{O}(N^3) で解くことができます。

少し工夫してみましょう。djdj を 10 進表記したときの桁数 kk ごとに遷移をまとめてみると、次のような手続きで解くことができます。

  • はじめ、dp[0][0]=1dp[0][0] = 1 とする。
  • iii=0,1,,N1i=0,1,\dots,N-1 の順に for- loop を行う。
    • jjj=0,1,,N1j=0,1,\dots,N-1 の順に for-loop を行う。
      • (後ろに加える文字列の長さ)の桁数が k=1,2,k=1,2,\dots の場合を考える。
      • すると、kk を固定した時 djdj10kdj<min(10k+1,Nj+1)10^k \leq dj \lt \min(10^{k+1}, N-j+1) の範囲になる。
      • 上に書いた djdj の範囲で for-loop を行う。
        • dp[i+k+1][j+dj]dp[i+k+1][j+dj] に加算を行う。

コードに起こすと次のようになります。(off-by-one error を回避するために配列や遷移先を少し大きめに取っています)

Copy
  1. N, P = map(int, input().split())
  2. dp = [[0] * (N + 10) for _ in range(N + 10)]
  3. dp[0][0] = 1
  4. for i in range(N):
  5. for j in range(N):
  6. ways = dp[i][j] * (26 if i == 0 else 25) % P
  7. L, k = 1, 1
  8. while L <= N - j:
  9. R = min(L * 10, N - j + 3)
  10. di = k + 1
  11. for dj in range(L, R):
  12. dp[i + di][j + dj] = (dp[i + di][j + dj] + ways) % P
  13. L, k = L * 10, k + 1
  14. print(sum([dp[i][N] for i in range(N)]) % P)
N, P = map(int, input().split())
dp = [[0] * (N + 10) for _ in range(N + 10)]
dp[0][0] = 1
for i in range(N):
  for j in range(N):
    ways = dp[i][j] * (26 if i == 0 else 25) % P
    L, k = 1, 1
    while L <= N - j:
      R = min(L * 10, N - j + 3)
      di = k + 1
      for dj in range(L, R):
        dp[i + di][j + dj] = (dp[i + di][j + dj] + ways) % P
      L, k = L * 10, k + 1

print(sum([dp[i][N] for i in range(N)]) % P)

ここで、11 行目~ 12 行目の最も内側のループがボトルネックになっているのでこれをうまく計算量改善できないか考えましょう。

Copy
  1. for dj in range(L, R):
  2. dp[i + di][j + dj] = (dp[i + di][j + dj] + ways) % P
for dj in range(L, R):
  dp[i + di][j + dj] = (dp[i + di][j + dj] + ways) % P

該当の場所を取り出しました。これは [L,R)[L, R)ways\mathrm{ways} を加算する操作を行っています。このような区間加算は imos 法を使うと以下のように 22 点更新に言い換えられます。

Copy
  1. dp[i + di][j + L] = (dp[i + di][j + L] + ways) % P
  2. dp[i + di][j + R] = (dp[i + di][j + R] - ways) % P
dp[i + di][j + L] = (dp[i + di][j + L] + ways) % P
dp[i + di][j + R] = (dp[i + di][j + R] - ways) % P

よって、「値を imos 法で更新して、値を取得する必要がある時に累積和で値を復元する」という少し工夫したアルゴリズムを使えば計算量を落とすことができます。

手続きは次のようになります。

  • はじめ、dp[0][0]=1dp[0][0] = 1 とする。
  • iii=0,1,,N1i=0,1,\dots,N-1 の順に for- loop を行う。
    • i0i \neq 0 の場合、データが imos 法の状態で保存されているので累積和を取ることで正しい値に復元する。
    • jjj=0,1,,N1j=0,1,\dots,N-1 の順に for- loop を行う。
      • 後ろに加える文字列の長さを桁数が k=1,2,k=1,2,\dots の場合を考える。
      • すると、kk を固定した時 djdj10kdj<min(10k+1,Nj+1)10^k \leq dj \lt \min(10^{k+1}, N-j+1) の範囲になる。
      • imos 法の要領で両端の値を更新する。

コードに起こすと次のようになります。

Copy
  1. N, P = map(int, input().split())
  2. dp = [[0] * (N + 10) for _ in range(N + 10)]
  3. dp[0][0] = 1
  4. for i in range(N):
  5. if i != 0:
  6. for j in range(N + 3):
  7. dp[i][j + 1] = (dp[i][j] + dp[i][j + 1]) % P
  8. for j in range(N):
  9. ways = dp[i][j] * (26 if i == 0 else 25) % P
  10. L, k = 1, 1
  11. while L <= N - j:
  12. R = min(L * 10, N - j + 3)
  13. di = k + 1
  14. dp[i + di][j + L] = (dp[i + di][j + L] + ways) % P
  15. dp[i + di][j + R] = (dp[i + di][j + R] - ways) % P
  16. L, k = L * 10, k + 1
  17. print(sum([dp[i][N] for i in range(N)]) % P)
N, P = map(int, input().split())
dp = [[0] * (N + 10) for _ in range(N + 10)]
dp[0][0] = 1
for i in range(N):
  if i != 0:
    for j in range(N + 3):
      dp[i][j + 1] = (dp[i][j] + dp[i][j + 1]) % P
  for j in range(N):
    ways = dp[i][j] * (26 if i == 0 else 25) % P
    L, k = 1, 1
    while L <= N - j:
      R = min(L * 10, N - j + 3)
      di = k + 1
      dp[i + di][j + L] = (dp[i + di][j + L] + ways) % P
      dp[i + di][j + R] = (dp[i + di][j + R] - ways) % P
      L, k = L * 10, k + 1

print(sum([dp[i][N] for i in range(N)]) % P)

計算量は i,ji,j のループが NN 回、kkO(log10N)\mathrm{O}(\log_{10} N) 回程度回るので全体で O(N2logN)\mathrm{O}(N^2 \log N) となります。

posted:
last update:



2025-03-29 (Sat)
09:55:59 +00:00