Official

E - Level K Palindrome Editorial by tatyam


\(S\) をレベル \(K\) 以上 にすることを考えると、同じ文字でなければならないグループがあることがわかります。それぞれのグループに何が何文字含まれるかを (長さ \(26\) の配列などで管理して) 求め、最も多いものをそのグループの文字として採用することで、(多くの場合) 必要な最小の書き換え回数を達成できます。

ただし、レベル \(K\) の回文の中のレベル \(0\) の部分が回文になってしまう場合があります。こうなると、ちょうどレベル \(K\) という条件を満たさないので、レベル \(0\) の部分 (中央 \(1\) 字を除く) のうち \(1\) つを選び、そのグループの \(2\) 番目に多い文字を採用することで対応します。

実装では、\(S\) を半分に折り畳むことを \(K\) 回行い、残ったレベル \(0\) の部分に対して計算を行っています。
時間計算量は \(\Theta(|S|)\) です。

回答例 (C++)

#include <iostream>
#include <string>
#include <numeric>
#include <vector>
#include <valarray>
using namespace std;
void chmin(int& a, int b){ if(a > b) a = b; }

int main(){
    int K;
    string S;
    cin >> K >> S;
    int N = S.size();
    vector c(N, valarray(0, 26));
    for(int i = 0; i < N; i++) c[i][S[i] - 'a']++;

    if(K > 20 || (N << 1 >> K) == 0 || (N >> K) == 1) return puts("impossible") & 0;
    
    int ans = 0, x = 0;
    for(; x < K; x++){
        const int L = N / 2, odd = N % 2;
        for(int i = 0; i < L; i++){
            c[i] += c.back();
            c.pop_back();
        }
        if(odd){
            ans += (1 << x) - c.back().max();
            c.pop_back();
        }
        N = L;
    }
    if(N == 0){
        cout << ans << endl;
        return 0;
    }
    {
        const int L = N / 2, odd = N % 2;
        int pal = 0x3fffffff;  // 回文にしないためのコスト
        for(int i = 0; i < L; i++){
            auto& a = c[i];
            auto& b = c.back();
            vector<int> ia(26), ib(26);
            iota(ia.begin(), ia.end(), 0);
            iota(ib.begin(), ib.end(), 0);
            partial_sort(ia.begin(), ia.begin() + 2, ia.end(), [&](int x, int y){ return a[x] > a[y]; });
            partial_sort(ib.begin(), ib.begin() + 2, ib.end(), [&](int x, int y){ return b[x] > b[y]; });
            if(ia[0] != ib[0]) pal = 0;
            ans += (2 << x) - a[ia[0]] - b[ib[0]];
            chmin(pal, a[ia[0]] - a[ia[1]]);
            chmin(pal, b[ib[0]] - b[ib[1]]);
            c.pop_back();
        }
        if(odd) ans += (1 << x) - c.back().max();
        ans += pal;
    }
    cout << ans << endl;
}

回答例 (Python)

from collections import Counter

K = int(input())
S = input()

N = len(S)
if N << 1 >> K == 0 or N >> K == 1:
    exit(print("impossible"))

S = [Counter({c : 1}) for c in S]
ans = 0
for x in range(K):
    L = N // 2
    odd = N % 2
    for i in range(L):
        S[i] += S[-1]
        S.pop()
    if odd:
        ans += (1 << x) - S[-1].most_common(1)[0][1]
        S.pop()
    N = L

if N == 0:
    exit(print(ans))

x = K
L = N // 2
odd = N % 2
pal = 1 << 30  # 回文にしないためのコスト
for i in range(L):
    a = S[i].most_common(2)
    b = S[-1].most_common(2)
    S.pop()
    if a[0][0] != b[0][0]:
        pal = 0
    a = [a[0][1], a[1][1]] if len(a) == 2 else [a[0][1], 0]
    b = [b[0][1], b[1][1]] if len(b) == 2 else [b[0][1], 0]
    ans += (2 << x) - a[0] - b[0]
    if pal > a[0] - a[1]:
        pal = a[0] - a[1]
    if pal > b[0] - b[1]:
        pal = b[0] - b[1]
if odd:
    ans += (1 << x) - S[-1].most_common(1)[0][1]
ans += pal
print(ans)

posted:
last update: