Official

E - Clamp Editorial by yuto1115

解説

まず、\(l > r\) であるようなクエリについては、どんな整数 \(x\) に対しても \( \max(l, \min(r, x))=l\) が成り立つため、単に \(l\times N\) を出力すればよいです。以下 \(l\leq r\) を仮定します。

\(\max(l, \min(r, A_i))\) という式の値について観察すると、以下のように \(A_i\) の値の範囲に応じた単純な形に書き表せることがわかります。

\[ \max(l, \min(r, A_i))= \left\{ \begin{array}{cl} l & (A_i < l) \\ A_i & (l \leq A_i \leq r) \\ r & (r < A_i) \end{array} \right. \]

したがって、\(A_i=j\) を満たす \(i\) の個数を \(C_j\) とおき、また数列の要素として現れうる最大値を \(K\) とおくと、

\[ \displaystyle \begin{aligned} \sum_{1\leq i\leq N} \max(l, \min(r, A_i)) &= \sum_{1\leq j\leq K} C_j \cdot \max(l, \min(r, j)) \\ &= \sum_{1\leq j\leq l-1} C_j \cdot l + \sum_{l\leq j\leq r} C_j \cdot j + \sum_{r+1\leq j\leq K} C_j \cdot r \\ &= \left(\sum_{1\leq j\leq l-1} C_j\right) \cdot l + \sum_{l\leq j\leq r} C_j \cdot j +\left (\sum_{r+1\leq j\leq K} C_j\right) \cdot r \\ \end{aligned} \]

というように変形することができます。

数列 \(A\) に加えて数列 \(C=(C_1,C_2,\dots,C_K)\) を管理することを考えます。与えられる更新クエリおよび求値クエリを高速に処理するためには、\(C\) に対する以下の操作を高速に処理できればよいです:

  • ある \(j\) に対し、\(C_j\) の値を変更する。
  • ある \(l, r\) に対し、\(\displaystyle \sum_{l\leq j\leq r}C_j\) の値を計算する。
  • ある \(l, r\) に対し、\(\displaystyle \sum_{l\leq j\leq r}C_j\cdot j\) の値を計算する。

これはセグメント木 (Segment Tree) と呼ばれるデータ構造によって実現することが可能です。詳細は下記の実装例および AC Library (Segtree) のドキュメント を参考にしてください。

計算量は \(O(N+Q\log K)\) です。

実装例 (C++) :

#include <bits/stdc++.h>
#include <atcoder/segtree>

using namespace std;
using namespace atcoder;

using ll = long long;

/*
  First element: Sum of C_j
  Second element: Sum of C_j \dot j
*/
using S = pair<int, ll>;

S op(const S &a, const S &b) {
    return {a.first + b.first, a.second + b.second};
}

S e() {
    return {0, 0};
}

const int C = 500010;
segtree<S, op, e> st(C);

void add(int x) {
    auto now = st.get(x);
    now.first += 1;
    now.second += x;
    st.set(x, now);
}

void del(int x) {
    auto now = st.get(x);
    now.first -= 1;
    now.second -= x;
    st.set(x, now);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, q;
    cin >> n >> q;
    vector<int> a(n);
    for (int &i: a) {
        cin >> i;
        add(i);
    }
    while (q--) {
        int t;
        cin >> t;
        if (t == 1) {
            int x, y;
            cin >> x >> y;
            --x;
            del(a[x]);
            a[x] = y;
            add(a[x]);
        } else {
            int l, r;
            cin >> l >> r;
            ll ans = 0;
            if (l > r) {
                ans = (ll) l * n;
            } else {
                ans += (ll) l * st.prod(0, l).first;
                ans += st.prod(l, r + 1).second;
                ans += (ll) r * st.prod(r + 1, C).first;
            }
            cout << ans << '\n';
        }
    }
}

posted:
last update: