公式

F - Concat (2nd) 解説 by physics0523

別解

まず、 \(1\) つ目の解説の前半をお読みください。

文字列 \(X,Y\) の比較に時間計算量 \(O(|X|+|Y|)\) かけたとしても、各要素に触れる回数をソート \(1\) 段あたり \(O(\log N)\) 回に抑え、ソート全体を \(O(\log N)\) 段の再帰で完了させることができれば、ソートを時間計算量 \(O(\Sigma |S_i| \log^2 N)\) で完了させることができます。

このようなソートは、マージソートに指数探索を組み合わせることで実現可能です。
具体的には、以下の手続きで実現可能です。

  • マージソートと同様に、配列を \(A,B\)\(2\) 等分し、先に \(A,B\) についてソートを完了させる。
  • その後、 \(A,B\) のマージを行う。マージが完了するまで、以下の手続きを繰り返す。
    • \(2\) つの配列を、交互に “kick側” として指定する。 “kick側” に指定された配列を \(X\) 、そうでない配列を \(Y\) とする。
    • ここでの目的は、 \(X\) の先頭の要素についてソートを完了させることです。ここで、 \(Y\) からあとどれだけの要素を前に入れればよいかを指数探索で求めます。
      • 指数探索の第 \(1\) フェーズとして、 \(X\) の先頭の要素は \(Y\)\(2^k\) 項目とを \(k=0,1,\dots\) の昇順に比較していきます。
      • 現在、「 \(X\) の先頭の要素は \(Y\)\(2^{k-1}\) 項目から \(2^k\) 項目までのどこかに入る」というようなことが分かっています。この具体的な位置を二分探索するのが指数探索の第 \(2\) フェーズです。
      • この結果、 \(X\) の先頭の要素が入るところまでのソートが完了します。

詳細は実装例も参照してください。

このようにすることで、 \(1\) 段中に特定の要素に触れる回数を \(O(\log N)\) とすることができます。証明は次の通りです。

  • 指数探索の \(X\) の先頭として各要素に触れる回数は高々 \(O(\log N)\) 回です。

  • 指数探索の \(Y\) の要素として各要素に触れる回数を考えましょう。

    • 指数探索の第 \(1\) フェーズで、各要素を \(2^k\) 項目の要素として触れる回数は各 \(k\) ごとに高々 \(1\) 回です。
      • kick 側を交互に指定することから、次にこの配列をもう一度 \(Y\) に指定する時にはその特定の要素は真に前進しているからです。
    • 指数探索の第 \(2\) フェーズに入る前に \(Y\) の先頭 \(2^{k-1}\) 項のソートを完了させそれらを削除したと捉えると、同様の理由から、指数探索の第 \(2\) フェーズでも各要素を \(2^k\) 項目の要素として触れる回数は各 \(k\) ごとに高々 \(1\) 回です。

以上より、全ての場合で特定の要素に触れる回数が高々 \(O(\log N)\) 回であることが確認できたため、証明が完了しました。

ここからの解法は、 \(1\) つ目の解説の後半を参照してください。

このようにすることで、全体で時間計算量 \(O(\Sigma |S_i| \log^2 N)\) で解くことができました。

このソートアルゴリズムは、 Timsort のアルゴリズムのアイデアを利用しており、 Timsort にこれと同様の処理が含まれています。
よって、標準のソートを Timsort で実装する CPython や OpenJDK などの環境では、 \(2\) つの文字列 \(|X|,|Y|\) の比較に時間計算量 \(O(|X|+|Y|)\) かかる関数を比較関数として渡した場合でもソートの計算量についてあまり考えなくとも AC を取ることが出来ます。

実装例 (C++):

#include<bits/stdc++.h>

using namespace std;

template <class RandomAccessIterator, class Compare> constexpr void exponential_merge_sort(
  RandomAccessIterator first, RandomAccessIterator last, Compare comp){
  auto len=(last-first);
  if(len<=1){return;}

  auto mid=first+len/2;
  exponential_merge_sort(first,mid,comp);
  exponential_merge_sort(mid,last,comp);

  auto a=vector(first,mid);
  auto b=vector(mid,last);
  auto ai=a.begin();
  auto bi=b.begin();
  bool side=false;
  while(ai!=a.end() || bi!=b.end()){
    if(ai==a.end()){
      while(bi!=b.end()){
        (*first)=(*bi);
        first++; bi++;
      }
      break;
    }
    else if(bi==b.end()){
      while(ai!=a.end()){
        (*first)=(*ai);
        first++; ai++;
      }
      break;
    }
    side=(!side);
    if(side){
      // kick a.begin();
      long long d=1;
      while(true){
        auto it=ranges::next(bi,d-1,b.end());
        if(it==b.end()){break;}
        if(comp((*it),(*ai))){d<<=1;}
        else{break;}
      }
      d>>=1;
      long long c=d;
      while(d>=2){
        d>>=1;
        auto it=ranges::next(bi,c+d-1,b.end());
        if(it==b.end()){continue;}
        if(comp((*it),(*ai))){c+=d;}
      }
      while(c--){
        (*first)=(*bi);
        first++; bi++;
      }
      (*first)=(*ai);
      first++; ai++;
    }
    else{
      // kick b.begin();
      long long d=1;
      while(true){
        auto it=ranges::next(ai,d-1,a.end());
        if(it==a.end()){break;}
        if(comp((*it),(*bi))){d<<=1;}
        else{break;}
      }
      d>>=1;
      long long c=d;
      while(d>=2){
        d>>=1;
        auto it=ranges::next(ai,c+d-1,a.end());
        if(it==a.end()){continue;}
        if(comp((*it),(*bi))){c+=d;}
      }
      while(c--){
        (*first)=(*ai);
        first++; ai++;
      }
      (*first)=(*bi);
      first++; bi++;
    }
  }
}

bool comp(const string &x,const string &y){
  return (x+y < y+x);
}

string concat(vector<string> &s){
  string res="";
  for(auto &nx : s){ res+=nx; }
  return res;
}

int main(){
  int t;
  cin >> t;
  while(t--){
    int n;
    cin >> n;
    vector<string> s(n);
    for(auto &nx : s){cin >> nx;}
    exponential_merge_sort(s.begin(),s.end(),comp);

    if(n==2){
      cout << s[1]+s[0] << "\n";
      continue;
    }
    
    bool ok=false;
    for(int i=1;i<n;i++){
      if(s[i-1]+s[i] == s[i]+s[i-1]){ok=true; break;}
    }
    if(ok){
      cout << concat(s) << "\n";
      continue;
    }

    swap(s[n-1],s[n-2]);
    string c1=concat(s);
    swap(s[n-1],s[n-2]);
    swap(s[n-2],s[n-3]);
    cout << min(c1,concat(s)) << "\n";
  }
  return 0;
}

投稿日時:
最終更新: