F - Double Sum Editorial by en_translator
This problem is a basic exercise in the sweep line technique. If you could not solve it, we recommend studying the sweep line algorithm.
The sought double sum is
\[\sum_{i=1}^N \sum_{i < j} \max(A_j - A_i, 0).\]
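To make the definition concrete, here is a direct \(\mathrm{O}(N^2)\) evaluation of the double sum in Python. It is only meant as a reference for checking faster solutions on small inputs. The function name `double_sum_bruteforce` is hypothetical, and indices are 0-based, so the condition \(i < j\) becomes `range(i + 1, n)`.

```python
def double_sum_bruteforce(A):
    """Evaluate sum over all pairs i < j of max(A[j] - A[i], 0) in O(N^2) time."""
    n = len(A)
    total = 0
    for i in range(n):
        for j in range(i + 1, n):
            total += max(A[j] - A[i], 0)
    return total


# Pairs (2, 5) -> 3, (2, 3) -> 1, (5, 3) -> 0
print(double_sum_bruteforce([2, 5, 3]))  # 4
```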
Here, since \(\max(A_j - A_i, 0)\) equals \(A_j - A_i\) when \(A_i \leq A_j\) and \(0\) otherwise, the contribution of a fixed \(i\) to the double sum can be rewritten as
\[ \begin{aligned} &\sum_{i < j}\max(A_j - A_i, 0) \\ &= \sum_{i < j, A_i \leq A_j} (A_j - A_i) \\ &= (\text{the sum of }A_j\text{ such that }i < j\text{ and }A_i \leq A_j) \\ &\quad - (\text{the number of }j\text{ such that }i < j\text{ and }A_i \leq A_j) \times A_i . \end{aligned} \]
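As a sanity check of this identity, the following sketch compares both sides for every \(i\) on a small, arbitrary array. The helper names `contribution_direct` and `contribution_split` are hypothetical. Note that ties \(A_j = A_i\) contribute \(0\) on both sides, so using \(\leq\) rather than \(<\) in the split is harmless.

```python
def contribution_direct(A, i):
    """Left-hand side for a fixed i: sum of max(A[j] - A[i], 0) over j > i."""
    return sum(max(A[j] - A[i], 0) for j in range(i + 1, len(A)))


def contribution_split(A, i):
    """Right-hand side: (sum of qualifying A[j]) - (number of qualifying j) * A[i]."""
    later = [a for a in A[i + 1:] if A[i] <= a]  # A[j] with j > i and A[i] <= A[j]
    return sum(later) - len(later) * A[i]


A = [3, 1, 4, 1, 5, 9, 2, 6]  # arbitrary example array
for i in range(len(A)):
    assert contribution_direct(A, i) == contribution_split(A, i)
print("the decomposition holds for every i")
```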
Using this fact, this problem can be solved by the sweep line algorithm as follows.
- Prepare the following two data structures:
  - A multiset \(S_0\) that supports two kinds of queries: inserting an element, and retrieving the number of elements not less than \(x\).
  - A multiset \(S_1\) that supports two kinds of queries: inserting an element, and retrieving the sum of elements not less than \(x\).
- Also, prepare a variable \(\mathrm{ans}\) that stores the answer. Initially, let \(\mathrm{ans} = 0\).
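- For each \(i = N, N-1, \dots, 2, 1\) in this order, perform the following. At the beginning of each iteration, \(S_0\) and \(S_1\) contain exactly \(A_{i+1}, \dots, A_N\).
  - Let \(c\) be the response to the query against \(S_0\) with \(x = A_i\).
  - Let \(s\) be the response to the query against \(S_1\) with \(x = A_i\).
  - Add \(s - c \times A_i\) to \(\mathrm{ans}\); by the decomposition above, this is exactly the contribution of \(i\).
  - Insert \(A_i\) into \(S_0\) and \(S_1\).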
- Print the resulting value of \(\mathrm{ans}\).
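The steps above can be sketched with deliberately naive multisets before bringing in any specialized data structure. In this illustration (not the intended solution), one sorted Python list plays the role of both \(S_0\) and \(S_1\), since the two multisets always hold the same elements. Insertion with `bisect.insort` and the slice sum each cost \(\mathrm{O}(N)\), so this sweep runs in \(\mathrm{O}(N^2)\) overall. The function name is hypothetical.

```python
import bisect


def double_sum_sweep_naive(A):
    """Right-to-left sweep; a sorted list stands in for both S_0 and S_1."""
    S = []  # sorted list of the elements already processed (those to the right)
    ans = 0
    for a in reversed(A):
        k = bisect.bisect_left(S, a)  # S[k:] are exactly the elements >= a
        c = len(S) - k                # S_0 query: number of elements not less than a
        s = sum(S[k:])                # S_1 query: sum of elements not less than a
        ans += s - c * a
        bisect.insort(S, a)           # insert a into the multiset
    return ans


print(double_sum_sweep_naive([2, 5, 3]))  # 4
```

Replacing the list with structures that answer both queries in logarithmic time is what turns this sketch into an efficient solution.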
\(S_0\) and \(S_1\) can be implemented as Fenwick trees (binary indexed trees) indexed by the coordinate-compressed values of \(A\); each insertion and each query takes \(\mathrm{O}(\log N)\) time.
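If the AtCoder Library is not available, the Fenwick trees can be written by hand. The following is a minimal sketch under that assumption: the class `Fenwick` and the function `double_sum_fenwick` are hypothetical names, the tree supports point addition and prefix sums, and each "not less than \(x\)" query is obtained as (total over all indices) minus (prefix below the compressed index of \(x\)).

```python
import bisect


class Fenwick:
    """Fenwick tree (binary indexed tree) on indices 0..n-1: point add, prefix sum."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # internal array is 1-indexed

    def add(self, p, x):
        """Add x to the value at index p (0-indexed)."""
        p += 1
        while p <= self.n:
            self.tree[p] += x
            p += p & -p

    def prefix(self, r):
        """Return the sum of the values at indices 0..r-1."""
        total = 0
        while r > 0:
            total += self.tree[r]
            r -= r & -r
        return total


def double_sum_fenwick(A):
    """O(N log N) sweep using two Fenwick trees over compressed values."""
    B = sorted(set(A))  # coordinate compression: B[k] is the k-th smallest distinct value
    M = len(B)
    cnt = Fenwick(M)  # S_0: number of inserted elements per compressed value
    tot = Fenwick(M)  # S_1: sum of inserted elements per compressed value
    ans = 0
    for a in reversed(A):
        k = bisect.bisect_left(B, a)  # compressed index of a
        c = cnt.prefix(M) - cnt.prefix(k)  # number of inserted elements >= a
        s = tot.prefix(M) - tot.prefix(k)  # sum of inserted elements >= a
        ans += s - c * a
        cnt.add(k, 1)
        tot.add(k, a)
    return ans


print(double_sum_fenwick([2, 5, 3]))  # 4
```

For the input \(A = (2, 5, 3)\) the pairs contribute \(3\), \(1\) and \(0\), so the call prints 4.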
Therefore, the problem can be solved in a total of \(\mathrm{O}(N \log N)\) time, which is fast enough.
- Sample code (Python)
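from atcoder.fenwicktree import FenwickTree  # AtCoder Library (Python port)
import bisect

N = int(input())
A = list(map(int, input().split()))

# Coordinate compression: B[k] is the k-th smallest distinct value of A
B = sorted(set(A))
M = len(B)
sum0 = FenwickTree(M)  # S_0: number of inserted elements per compressed value
sum1 = FenwickTree(M)  # S_1: sum of inserted elements per compressed value
ans = 0
# Sweep i = N-1, ..., 0 (0-indexed); the trees hold exactly A[i+1:]
for i in reversed(range(N)):
    k = bisect.bisect_left(B, A[i])  # compressed index of A[i]
    c = sum0.sum(k, M)  # number of later elements not less than A[i]
    s = sum1.sum(k, M)  # sum of later elements not less than A[i]
    ans += s - c * A[i]
    sum0.add(k, 1)
    sum1.add(k, A[i])
print(ans)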