Official

O - 整地クエリ/Flatten Query Editorial by yuto1115

解説

クエリの内容を要約すると以下の通りです。

  • クエリ \(1\) : \(k,d\ (=\pm 1)\) が与えられる。\(A_k\)\(d\) を足す。
  • クエリ \(2\) : \(x\) が与えられる。\(\displaystyle \sum_{i=1}^{N} |A_i-x|\) を求める。

\(A\) そのものを管理しているだけだとクエリ \(2\) を高速に処理できないので、\(A\) を昇順にソートして得られる列(\(B\) とおきます)を追加で管理することを考えます。簡単のため、\(B\) に関する以下の \(2\) 種類の関数を定義します。

  • \(\mathrm{lb}(x) := (B_i\geq x\) を満たす最小の \(i\)(存在しないならば \(N+1\)\()\)
  • \(\mathrm{sum}(l, r) := \displaystyle \sum_{i=l}^{r} B_i\)

このとき、クエリ \(2\) に対する答えは、\(\displaystyle \sum_{i=1}^{\mathrm{lb}(x)-1}(x-B_i) + \sum_{i=\mathrm{lb}(x)}^{N}(B_i-x) = x(2\cdot\mathrm{lb}(x)-2-N)- \mathrm{sum}(1,\mathrm{lb}(x)-1)+\mathrm{sum}(\mathrm{lb}(x),N)\) というように \(\mathrm{lb,sum}\) を用いて求めることができます。

次に、クエリ \(1\) について考えます。\(B\) 内のどこに \(A_k\) が位置しているかを求めてそこに \(d\) を足せばよいのですが、\(B\) が常に昇順に並んでいる状態にするためには多少の工夫が必要です。具体的には以下の通りです。

  • \(l=\mathrm{lb}(A_k),r=\mathrm{lb}(A_k+1)-1\) とおく。(すなわち、\(B\) 内で \(A_k\) と同じ値が現れる添字の範囲が \([l,r]\) である。)
  • \(d=1\) ならば、\(B_r\)\(1\) を足す。\(d=-1\) ならば、\(B_l\)\(-1\) を足す。

まとめると、\(A\) を昇順にソートして得られる列 \(B\) に対し、以下の \(3\) 種類の操作を高速に行えればよいです。

  • ある \(x\) に対し、\(\mathrm{lb}(x)\) を求める。
  • ある \(l,r\) に対し、\(\mathrm{sum}(l,r)\) を求める。
  • ある \(i, d\) に対し、\(b_i\)\(d\) を足す。

これは、各ノードに範囲内の要素の総和と最大値を持たせた segment tree によって各操作 \(O(\log N)\) で処理できます。

よって、計算量 \(O((N+Q)\log N)\) で本問題を解くことができました。

実装例 (C++) :

#include<bits/stdc++.h>
#include<atcoder/segtree>

using namespace std;
using namespace atcoder;

using ll = long long;

// {sum, max}
using S = pair<ll, int>;

S op(S a, S b) {
    return {a.first + b.first, max(a.second, b.second)};
}

S e() {
    return {0, -(int) (2e9)};
}

int main() {
    int n;
    cin >> n;
    vector<int> a(n);
    for (int &i: a) cin >> i;
    
    auto b = a;
    sort(b.begin(), b.end());
    vector<S> init(n);
    for (int i = 0; i < n; i++) init[i] = {b[i], b[i]};
    segtree<S, op, e> st(init);
    
    // b_i >= x を満たす最小の i
    auto lb = [&](int x) -> int {
        return st.max_right(0, [&](auto val) { return val.second < x; });
    };
    
    int q;
    cin >> q;
    while (q--) {
        int t;
        cin >> t;
        if (t == 1) {
            int k, d;
            cin >> k >> d;
            --k;
            int l = lb(a[k]);
            int r = lb(a[k] + 1) - 1;
            if (d == 1) st.set(r, {a[k] + 1, a[k] + 1});
            else st.set(l, {a[k] - 1, a[k] - 1});
            a[k] += d;
        } else {
            int x;
            cin >> x;
            int p = lb(x);
            ll ans = (ll) x * (2 * p - n) - st.prod(0, p).first + st.prod(p, n).first;
            cout << ans << '\n';
        }
    }
}

posted:
last update: