Official

E - Best Performances Editorial by physics0523


この問題には様々な解法がありますが、ここでは「同じデータ構造を \(2\) つ使う」方針を使います。

今回はデータ構造として multiset(多重集合を管理するデータ構造) を用います。使う機能は以下の通りなので、同等の操作ができるならどれを使っても構いません。

  • 多重集合を管理する。即ち、同一の値を複数保持することができる。
  • 最大値(または最小値、またはその両方) にアクセスできる。
  • 特定の値を \(1\) つ削除することができる。

「大きい方から \(K\) 個の値を管理する multiset \(X\) 」「それ以外の値を管理する multiset \(Y\) 」「 \(X\) の総和 \(s\) 」という \(3\) つの要素を管理していくことを考えます。

最初 \(X\) は大きい方から \(K\) 個の値 (今回は \(K\) 個の \(0\) ) 、 \(Y\)\(A\) 内のそれ以外の値 (今回は \(N-K\) 個の \(0\) ) 、 \(s\)\(X\) の総和と初期化します。

以下の操作 balance, add, erase を定義します。これは \(2\) つの multiset で実装できます。各操作を実装していく道中で \(s\) も更新していきます。

  • balance\(X\) の要素数が \(K\) 個でないなら、 \(K\) 個になるまで \(Y\) の要素のうち最大のものを \(X\) に移動することを繰り返す。そのうえで、 ( \(X\) の最小値) \(<\) ( \(Y\) の最大値) が満たされる限りこれらを入れ替えることを繰り返す。
  • add\(Y\) に特定の値 \(v\) を追加した上で balance する。
  • erase … 特定の値 \(v\) を削除する。 \(X\)\(v\) が含まれるならそれを消し、そうでないなら \(Y\) から \(v\) を消す。その後 balance する。

(おまけ: 今回の操作では \(X\) の要素数が \(K\) を上回ることはないですが、 \(X\) の要素数が \(K\) を上回るような更新があるようなケースでも \(X\) のうち小さい要素を \(Y\) に移動させれば上手く取り扱えます。)

\(A_i\) を変更する時は以下の操作をかければよいです。

  • 新しい \(A_i\)add する。
  • その後、元の \(A_i\)erase する。

(この順番は逆でも良いですが \(X\)\(Y\) の大きさの合計が \(K\) を下回るケースや \(N=1\) の時に \(X,Y\) が双方空になるケースなどがあり実装がしづらくなります。なお、今回の問題では番兵として \(0\) をたくさん入れるという対処法もあります。)

今回の問題では、 balance 内での swap の回数が \(O(1)\) であると見積もられるので、全体の計算量は \(O(Q \log N)\) となります。

この解法は他の問題にも応用できますが、実装上の注意として、操作の道中で \(X\)\(Y\) が空集合になったりする可能性があることに注意してください。

実装例 (C++):

#include<bits/stdc++.h>

using namespace std;

int k;
multiset<int> x,y;
long long s;

void balance(){
  while(x.size()<k){
    auto iy=y.end();iy--;
    x.insert((*iy));
    s+=(*iy);
    y.erase(iy);
  }
  if(x.empty() || y.empty()){return;}

  while(1){
    auto ix=x.begin();
    auto iy=y.end();iy--;
    int ex=(*ix);
    int ey=(*iy);
    if(ex >= ey){break;}
    s+=(ey-ex);
    x.erase(ix);
    y.erase(iy);
    x.insert(ey);
    y.insert(ex);
  }
}

void add(int v){
  y.insert(v);
  balance();
}

void erase(int v){
  auto ix=x.find(v);
  if(ix!=x.end()){ s-=v; x.erase(ix); }
  else{ y.erase(y.find(v)); }
  balance();
}

int main(){
  int n;
  cin >> n >> k;
  vector<int> a(n,0);
  for(int i=0;i<k;i++){ x.insert(0); }
  for(int i=k;i<n;i++){ y.insert(0); }
  s=0;
  
  int q;
  cin >> q;
  while(q>0){
    q--;
    int p,w;
    cin >> p >> w;
    p--;
    add(w);
    erase(a[p]);
    a[p]=w;
    cout << s << "\n";
  }
  return 0;
}

posted:
last update: