公式

F - Range Division 解説 by sounansya


\(M=\log_2 \max A\ (=60)\) とします。

まず、二進数を 01 文字列に変換した以下の問題を考えます。

長さ \(N\) の 01 文字列の列 \(S=(S_1,S_2,\ldots,S_N)\) が与えられます。

\(S\) に対して以下の一連の操作を行うことを考えます:

  • 以下の条件を全て満たす整数の組 \((l,r)\) を選ぶ。
    • \(1\le l\le r\le N\)
    • \(S_l,S_{l+1},\ldots,S_r\) の末尾が全て同じ文字である
  • \(k=l,l+1,\ldots,r\) に対し、\(S_k\) の末尾の文字を消す。

\(S\) の要素を全て空文字列にするために必要な操作回数の最小値を求めてください。

この問題の答えは隣接する文字列を同時に消せるだけ消すことを考えることで \(\displaystyle \sum_{i=1}^N |S_i| - \sum_{i=1}^{N-1}\text{LCS}(S_i,S_{i+1})\) であることが分かります。

これを元にこの問題を解くことを考えます。


\(A_i=0\) の時も操作が可能であることを考えると、問題は以下に帰着されます:

\(A_i\) を二進数で文字列表記したものを \(S_i\) とし、\(S_i\) の先頭に 0\(k\) 文字つけた文字列を \(S_i^k\) とする。非負整数列 \(c=(c_1,c_2,\ldots,c_N)\) 全てに対する \(\displaystyle \sum_{i=1}^N \left(|S_i|+c_i \right)-\sum_{i=1}^{N-1} \text{LCS}(S_i^{c_i},S_{i+1}^{c_{i+1}})\) の最小値を求めよ。

したがって、以下の \(2\) つの計算がそれぞれできればこの問題を解くことができます。

  1. \(\text{LCS}(S_i^{c_i},S_{i+1}^{c_{i+1}})\) を高速に計算する。
  2. 1 を元に DP を用いて \(\displaystyle \sum_{i=1}^N \left(|S_i|+c_i \right)-\sum_{i=1}^{N-1} \text{LCS}(S_i^{c_i},S_{i+1}^{c_{i+1}})\) の最小値を求める。

最適解における \(c_i\) の最大値は \(60\) 以下とは限らないことに注意してください。


1. \(\text{LCS}(S_i^{c_i},S_{i+1}^{c_{i+1}})\) の計算

\(m=\min(c_i,c_{i+1})\) として、\(\text{LCS}(S_i^{c_i},S_{i+1}^{c_{i+1}})=\text{LCS}(S_i^{c_i-m},S_{i+1}^{c_{i+1}-m})+m\) が成り立ちます。\(c_i-m,c_{i+1}-m\) の少なくとも一方は \(0\) なので、各 \(k\) に対し \(\text{LCS}(S_i^k,S_{i+1})\)\(\text{LCS}(S_i,S_{i+1}^k)\) の値を計算すれば良いです。\(k\geq 60\) では常にこれらの値はそれぞれ同じであることを考えると、通常の LCS を DP で計算する方法でそれぞれ \(O(M^2)\) で計算することができます。

2. \(\displaystyle \sum_{i=1}^N \left(|S_i|+c_i \right)-\sum_{i=1}^{N-1} \text{LCS}(S_i^{c_i},S_{i+1}^{c_{i+1}})\) の最小値の計算

\(d[k][v]\) を「\(c_1\) から \(c_k\) まで決めて、\(c_k=v\) である場合の \(\displaystyle \sum_{i=1}^k (|S_i|+c_i)-\sum_{i=1}^{k-1}\text{LCS}(S_i^{c_i},S_{i+1}^{c_{i+1}})\) の最小値」と定義します。\(1\le k\le N,\ 0\le v\le\color{red} (N-1)M\color{black}\) の範囲を計算する必要があることに注意してください。この DP は \(O(N^3M^2)\) 時間で計算できます。


以上を適切に実装することでこの問題に正答することができます。計算量は \(O(N^3M^2)\) です。

DP を高速化することで \(O(N^2M^2)\) 時間で解くこともできます。

実装例(Python3)

import sys

input = sys.stdin.readline
M = 60
INF = 10**9


def f(s: str, t: str):
    s += "0" * M
    n, m = len(s), len(t)
    d = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n):
        for j in range(m):
            if s[i] == t[j]:
                d[i + 1][j + 1] = d[i][j] + 1
            else:
                d[i + 1][j + 1] = max(d[i][j + 1], d[i + 1][j])
    ans = []
    for i in range(n - M, n + 1):
        ans.append(d[i][m])
    return ans


for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    s = ["" if x == 0 else bin(x)[2:][::-1] for x in a]
    L = M * (n + 1) + 1
    d = [j for j in range(L)]
    for i in range(n - 1):
        dd = [INF] * L
        r1 = f(s[i], s[i + 1])
        r2 = f(s[i + 1], s[i])
        for j1 in range(L):
            for j2 in range(L):
                if j1 < j2:
                    dd[j2] = min(dd[j2], d[j1] + j2 - j1 - r2[min(M, j2 - j1)])
                else:
                    dd[j2] = min(dd[j2], d[j1] + j2 - j2 - r1[min(M, j1 - j2)])
        d = dd
    print(sum(len(x) for x in s) + min(d))

原案:sounansya

投稿日時:
最終更新: