Official

F - Merge Sets Editorial by en_translator


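First, the position of an element in the sorted sequence of \(A \cup B\) equals the number of elements of \(A \cup B\) that are no greater than it, so the definition of \(f(A,B)\) can be rephrased as follows.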

  • \(\displaystyle f(A,B) = \sum_{i \in A} (\text{The number of integers no greater than }i\text{ that are contained in }A \cup B).\)
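
The rephrasing can be checked on small inputs. The sketch below assumes that the problem statement defines \(f(A,B)\) as the sum of the 1-based positions of the elements of \(A\) in the sorted sequence of \(A \cup B\), with \(A\) and \(B\) disjoint; the function names `f_by_definition` and `f_rephrased` are hypothetical and do not appear in the editorial.

```cpp
#include <algorithm>
#include <vector>

// f computed from the (assumed) original definition: sort A union B and
// add up the 1-based positions at which the elements of A appear.
long long f_by_definition(const std::vector<int>& A, const std::vector<int>& B) {
    std::vector<int> C(A);
    C.insert(C.end(), B.begin(), B.end());
    std::sort(C.begin(), C.end());
    long long total = 0;
    for (int a : A) {
        // All elements are distinct, so lower_bound finds the unique position of a.
        total += (std::lower_bound(C.begin(), C.end(), a) - C.begin()) + 1;
    }
    return total;
}

// f computed from the rephrased form: for each a in A, count the elements
// of A union B that are no greater than a.
long long f_rephrased(const std::vector<int>& A, const std::vector<int>& B) {
    long long total = 0;
    for (int a : A) {
        for (int x : A) if (x <= a) ++total;
        for (int x : B) if (x <= a) ++total;
    }
    return total;
}
```

For \(A = \{1, 3\}\) and \(B = \{2, 8\}\), both functions return \(1 + 3 = 4\).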

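Using this fact, the sum to be computed can be transformed as follows, where \(A_{i,k}\) denotes the \(k\)-th element of \(S_i\). The second equality holds because \(S_i\) and \(S_j\) are disjoint, and the third because the counts within \(S_i\) are \(1, 2, \dots, M\) in some order.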

\[ \begin{aligned} \sum_{1 \leq i < j \leq N} f(S_i,S_j) &= \sum_{1 \leq i < j \leq N} \sum_{k=1}^{M} (\text{The number of integers no greater than } A_{i,k}\text{ that are contained in }S_i \cup S_j )\\ &= \sum_{1 \leq i < j \leq N} \left( \sum_{k=1}^{M} (\text{The number of integers no greater than } A_{i,k}\text{ that are contained in }S_i ) + \sum_{k=1}^{M} (\text{The number of integers no greater than } A_{i,k}\text{ that are contained in }S_j ) \right)\\ &= \sum_{1 \leq i < j \leq N} \left( \frac{M(M+1)}{2} + \sum_{k=1}^{M} (\text{The number of integers no greater than } A_{i,k}\text{ that are contained in }S_j ) \right)\\ &= \frac{M(M+1)}{2} \cdot \frac{N(N-1)}{2} + \sum_{1 \leq i < j \leq N}\sum_{k=1}^{M} (\text{The number of integers no greater than } A_{i,k}\text{ that are contained in }S_j )\\ &= \frac{M(M+1)}{2} \cdot \frac{N(N-1)}{2} + \sum_{i=1}^{N}\sum_{k=1}^{M} (\text{The number of integers no greater than } A_{i,k}\text{ that are contained in }S_{i+1}\cup \dots \cup S_N ).\\ \end{aligned} \]
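
As a sanity check of this derivation, the following sketch evaluates the pairwise form and the final form by brute force and lets the two be compared on small inputs. It assumes that all sets have the same size \(M\) and that all values are pairwise distinct; the names `count_le`, `total_pairwise`, and `total_final_form` are hypothetical helpers introduced only for illustration.

```cpp
#include <cstddef>
#include <vector>

using Sets = std::vector<std::vector<int>>;

// Number of elements of S that are no greater than x (linear scan).
long long count_le(const std::vector<int>& S, int x) {
    long long c = 0;
    for (int y : S) if (y <= x) ++c;
    return c;
}

// Pairwise form: for every i < j and every a in S_i, the count in S_i plus
// the count in S_j (the split is valid because the sets are disjoint).
long long total_pairwise(const Sets& S) {
    long long total = 0;
    for (std::size_t i = 0; i < S.size(); ++i)
        for (std::size_t j = i + 1; j < S.size(); ++j)
            for (int a : S[i])
                total += count_le(S[i], a) + count_le(S[j], a);
    return total;
}

// Final form: the closed-form term plus, for every a in S_i, the count of
// elements no greater than a in the later sets S_{i+1}, ..., S_N.
long long total_final_form(const Sets& S) {
    const long long n = static_cast<long long>(S.size());
    const long long m = S.empty() ? 0 : static_cast<long long>(S[0].size());
    long long total = m * (m + 1) / 2 * (n * (n - 1) / 2);
    for (std::size_t i = 0; i < S.size(); ++i)
        for (int a : S[i])
            for (std::size_t j = i + 1; j < S.size(); ++j)
                total += count_le(S[j], a);
    return total;
}
```

For the sets \(\{1,3\}, \{2,8\}, \{4,6\}\), the closed-form term is \(3 \cdot 3 = 9\) and the remaining sum is \(3\), so both functions return \(12\).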

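The double sum in the last line can be computed in a total of \(O(NM \log (NM))\) time by processing \(i = N, N-1, \dots, 1\) in this order with a Fenwick tree: for each element \(A_{i,k}\), query the number of already inserted values that are smaller than it, and after all queries for \(S_i\) are done, insert the elements of \(S_i\). Note that the values of \(A\) must be coordinate-compressed first so that they can be used as indices of the tree.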
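
The coordinate compression step can be sketched in isolation as follows. This is a minimal illustration that assumes all values are pairwise distinct, so every value maps to a unique 0-based rank; the function name `compress` is hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Replace every value by its 0-based rank among all values in the input.
// Assumes all values are pairwise distinct.
std::vector<std::vector<int>> compress(std::vector<std::vector<int>> a) {
    std::vector<int> sorted_vals;
    for (const auto& row : a)
        sorted_vals.insert(sorted_vals.end(), row.begin(), row.end());
    std::sort(sorted_vals.begin(), sorted_vals.end());
    for (auto& row : a) {
        for (int& x : row) {
            // The rank of x is the number of values strictly smaller than x.
            x = static_cast<int>(
                std::lower_bound(sorted_vals.begin(), sorted_vals.end(), x) -
                sorted_vals.begin());
        }
    }
    return a;
}
```

After compression, every value lies in \([0, NM)\), so a Fenwick tree of size \(NM\) is enough regardless of how large the original values are.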

Sample code (C++):

#include<bits/stdc++.h>
#include<atcoder/fenwicktree>

using namespace std;
using namespace atcoder;

using ll = long long;

int main() {
    int n, m;
    cin >> n >> m;
    vector a(n, vector<int>(m));
    vector<int> v;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            cin >> a[i][j];
            v.push_back(a[i][j]);
        }
    }
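    sort(v.begin(), v.end()); // sorted values, used for coordinate compression

    // Each of the N(N-1)/2 pairs contributes M(M+1)/2 from S_i itself.
    ll ans = (ll) m * (m + 1) / 2 * n * (n - 1) / 2;
    fenwick_tree<int> fw(n * m);
    // Scan the sets from last to first so that fw holds S_{i+1}, ..., S_N.
    for (int i = n - 1; i >= 0; i--) {
        for (int j = 0; j < m; j++) {
            // Replace the value by its rank among all values.
            a[i][j] = lower_bound(v.begin(), v.end(), a[i][j]) - v.begin();
            // Count the already inserted values smaller than a[i][j].
            ans += fw.sum(0, a[i][j]);
        }
        // Insert S_i only after all of its queries are done.
        for (int j = 0; j < m; j++) {
            fw.add(a[i][j], 1);
        }
    }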
    cout << ans << endl;
}
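
If the AtCoder Library is not available, the same scan can be sketched with a hand-written Fenwick tree. The following is not the editorial's code but a minimal library-free variant under the same assumptions (equal-sized sets, pairwise distinct values); the names `Fenwick` and `merge_sets_total` are hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Minimal Fenwick tree over indices 0..n-1 supporting point add and prefix sum.
struct Fenwick {
    std::vector<int> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    // Adds v to the element at 0-based index i.
    void add(int i, int v) {
        for (++i; i < static_cast<int>(t.size()); i += i & -i) t[i] += v;
    }
    // Returns the sum of the elements at indices 0..r-1.
    int prefix(int r) const {
        int s = 0;
        for (; r > 0; r -= r & -r) s += t[r];
        return s;
    }
};

// Computes the sum of f(S_i, S_j) over all i < j using the final formula.
long long merge_sets_total(const std::vector<std::vector<int>>& a) {
    const long long n = static_cast<long long>(a.size());
    const long long m = a.empty() ? 0 : static_cast<long long>(a[0].size());

    // Sorted list of all values, used for coordinate compression.
    std::vector<int> vals;
    for (const auto& row : a) vals.insert(vals.end(), row.begin(), row.end());
    std::sort(vals.begin(), vals.end());

    long long ans = m * (m + 1) / 2 * (n * (n - 1) / 2);
    Fenwick fw(static_cast<int>(vals.size()));
    // Process the sets from last to first so that fw holds only later sets.
    for (long long i = n - 1; i >= 0; --i) {
        std::vector<int> ranks;
        for (int x : a[i]) {
            int r = static_cast<int>(
                std::lower_bound(vals.begin(), vals.end(), x) - vals.begin());
            ranks.push_back(r);
            ans += fw.prefix(r);  // values in later sets smaller than x
        }
        // Insert the current set only after all of its queries are answered.
        for (int r : ranks) fw.add(r, 1);
    }
    return ans;
}
```

Wrapping the computation in a function rather than `main` makes it easy to compare against a brute-force implementation on small random inputs.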
