Official

C - Merge Sequences Editorial by MMNMM


\(C=(C _ 1,C _ 2,\ldots,C _ {N+M})\) は、列を連結してソートすることで \(O((N+M)\log(N+M))\) 時間で求めることができます。 これで十分高速ですが、より高速に \(C\) を求めるアルゴリズムもあります。

ソート列をマージする高速なアルゴリズム( \(O(N+M)\) 時間)

次のアルゴリズムを用いると \(C\) を \(O(N+M)\) 時間で求めることができます。

  • はじめ \(i=j=k=1\) として、\(i\leq N+M\) である限り次を繰り返す。
    • \(k>M\) であるか、\(j\leq N\) かつ \(A _ j\leq B _ k\) のとき、
      • \(C _ i=A _ j\) とし、\(i,j\) を \(1\) 増やす。
    • そうでないとき、
      • \(C _ i=B _ k\) とし、\(i,k\) を \(1\) 増やす。

C++ では、std::merge 関数でこれを行うことができます。


\(C\) が具体的に得られたら、(必要ならば適切な前計算ののち)\(C _ i\) の値から \(i\) を計算することが \(O(\log(N+M))\) 時間や expected \(O(1)\) 時間で可能です。 前者は \(C\) に対する二分探索や(平衡二分探索木などによる)連想配列、後者は(ハッシュマップによる)連想配列で実現できます。

よって、この問題を \(O((N+M)\log(N+M))\) 時間や expected \(O(N+M)\) 時間で解くことができました。

ソート列のマージの際に答えを計算したり、適切な情報と組にしてマージすることで worst \(O(N+M)\) 時間で解くこともできます。

実装例は以下のようになります。

  • python による解法
N, M = map(int, input().split())
A = list(map(int, input().split()))
B = list(map(int, input().split()))

C = sorted(A + B)

get_index = {c: i + 1 for i, c in enumerate(C)}

for a in A:
    print(get_index[a])

for b in B:
    print(get_index[b])
  • C++ による \(O((N+M)\log(N+M))\) 時間の解法
#include <iostream>
#include <vector>
#include <algorithm>

int main() {
    using namespace std;

    unsigned N, M;
    cin >> N >> M;

    vector<unsigned> A(N), B(M);
    for (auto &&a : A) cin >> a;
    for (auto &&b : B) cin >> b;

    vector<unsigned> C(N + M);
    merge(begin(A), end(A), begin(B), end(B), begin(C));

    for (const auto a : A)
        cout << lower_bound(begin(C), end(C), a) - begin(C) + 1 << " ";
    cout << endl;

    for (const auto b : B)
        cout << lower_bound(begin(C), end(C), b) - begin(C) + 1 << " ";
    cout << endl;

    return 0;
}
  • C++ による worst \(O(N+M)\) 時間の解法
#include <iostream>
#include <vector>
#include <utility>
#include <algorithm>

int main() {
    using namespace std;

    unsigned N, M;
    cin >> N >> M;

    vector<pair<unsigned, unsigned>> A(N), B(M);
    for (auto&&[a, _] : A) cin >> a;
    for (auto&&[b, _] : B) cin >> b;
    for (unsigned i{}; i < N; ++i) A[i].second = i;
    for (unsigned i{}; i < M; ++i) B[i].second = i + N;

    vector<pair<unsigned, unsigned>> C(N + M);
    merge(begin(A), end(A), begin(B), end(B), begin(C));

    vector<unsigned> ans(N + M);
    for (unsigned i{}; i < N + M; ++i) ans[C[i].second] = i;

    for (unsigned i{}; i < N + M; ++i) cout << ans[i] + 1 << " ";

    return 0;
}

posted:
last update: