G - Redistribution of Piles Editorial by en_translator
We call the two operations operation \(A\) and operation \(B\). The key observation is the following:
- We can assume that operation \(A\) is never performed right after operation \(B\), without changing the answer.
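This is because, after operation \(B\), every dish contains at least one stone, so an operation \(A\) performed immediately afterwards takes exactly one stone from every dish and simply undoes that \(B\). Such a pair can be removed from the sequence without changing the result.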
Similarly, we can see the following fact:
- We can assume that operation \(B\) is never performed right after taking stones from all the dishes in an operation \(A\), without changing the answer.
Conversely, the resulting sequences are in one-to-one correspondence with the sequences of operations that obey the two rules above. Thus, the problem reduces to counting such sequences of operations.
By the two rules, every such sequence can be written as “perform \(A\) \(x\) times, then \(B\) \(y\) times,” so we describe it by a pair of integers \(x\) and \(y\). For a fixed \(x\), the range of \(y\) and the number of candidates are as follows.
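For small inputs, the count can be cross-checked by brute force. The following is a minimal sketch, not part of the original solution. It assumes that operation \(A\) moves one stone from every non-empty dish into the bag, that operation \(B\) moves \(N\) stones from the bag onto the dishes (one each) and is allowed only when the bag holds at least \(N\) stones, and that the bag starts empty. The name `count_reachable_bruteforce` is hypothetical.

```cpp
#include <cassert>
#include <numeric>
#include <queue>
#include <set>
#include <vector>

using Dishes = std::vector<long long>;

// Counts distinct reachable dish configurations by BFS.
// The total number of stones is conserved, so the bag content is
// determined by the dishes and the state space is finite.
// Only practical for tiny inputs.
long long count_reachable_bruteforce(const Dishes& start) {
    const long long n = static_cast<long long>(start.size());
    const long long total = std::accumulate(start.begin(), start.end(), 0LL);
    std::set<Dishes> seen;
    std::queue<Dishes> todo;
    seen.insert(start);
    todo.push(start);
    while (!todo.empty()) {
        Dishes cur = todo.front();
        todo.pop();
        const long long bag =
            total - std::accumulate(cur.begin(), cur.end(), 0LL);
        // Operation A (assumed): one stone from each non-empty dish to the bag.
        Dishes after_a = cur;
        for (long long& v : after_a) {
            if (v > 0) --v;
        }
        if (seen.insert(after_a).second) todo.push(after_a);
        // Operation B (assumed): needs at least n stones in the bag.
        if (bag >= n) {
            Dishes after_b = cur;
            for (long long& v : after_b) ++v;
            if (seen.insert(after_b).second) todo.push(after_b);
        }
    }
    return static_cast<long long>(seen.size());
}
```

For example, calling `count_reachable_bruteforce({3, 1, 3})` explores the states \((3,1,3)\), \((2,0,2)\), \((1,0,1)\), \((0,0,0)\), \((2,1,2)\), \((1,1,1)\), \((2,2,2)\).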
- If \(x \leq \min \lbrace A \rbrace\): \(y = 0\) (one candidate), because every operation \(A\) took stones from all the dishes, so performing \(B\) violates the second condition
- If \(\min \lbrace A \rbrace \lt x\): \(0 \leq y \leq \lfloor s/N \rfloor\), where \(s\) is the number of stones in the bag after the \(x\) operations \(A\) (\( \lfloor s/N \rfloor + 1\) candidates)
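Thus, we can scan every \(x\) from \(0\) to \(\max \lbrace A \rbrace\) (beyond that, operation \(A\) no longer changes anything) and solve this problem in a total of \(\mathrm{O}(\max \lbrace A \rbrace)\) time, but \(\max \lbrace A \rbrace\) is so large that this leads to TLE (Time Limit Exceeded).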
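The exhaustive scan over \(x\) might look like the sketch below. It is an illustration only: the name `count_by_scanning` is hypothetical, the result is not reduced modulo anything, and it assumes the bag starts empty so that the bag holds \(\sum_j \min(A_j, x)\) stones after \(x\) operations \(A\).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Scans x = 0, 1, ..., max(A) and adds the number of valid y for each x.
// Runs in O(N log N + max(A)); intended only for small values.
long long count_by_scanning(std::vector<long long> a) {
    std::sort(a.begin(), a.end());
    const long long n = static_cast<long long>(a.size());
    const long long mn = a.front();
    const long long mx = a.back();
    std::size_t emptied = 0;  // dishes with fewer than x stones initially
    long long s = 0;          // stones in the bag after x operations A
    long long ans = 1;        // x = 0 allows only y = 0
    for (long long x = 1; x <= mx; ++x) {
        while (emptied < a.size() && a[emptied] < x) ++emptied;
        // The x-th operation A takes a stone from every dish with >= x stones.
        s += n - static_cast<long long>(emptied);
        ans += (x <= mn) ? 1 : s / n + 1;
    }
    return ans;
}
```

For the input \(A = (3, 1, 3)\) the per-\(x\) contributions are \(1, 1, 2, 3\) for \(x = 0, 1, 2, 3\).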
However, the sum splits into \(\mathrm{O}(N)\) sums of the form \(\sum_{L \leq x \leq R} \lfloor (ax+b)/m \rfloor\), so we can apply an algorithm called floor sum. (Floor sum can be computed easily with AtCoder Library; see ACL (AtCoder Library)’s reference.)
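To give an idea of what a floor sum routine does, here is a minimal sketch of the usual Euclid-like reduction. It is not the library's code: the name `floor_sum_sketch` is hypothetical, it assumes \(n \geq 0\), \(m \geq 1\), \(a, b \geq 0\), and it uses signed 64-bit arithmetic, so it is only safe when intermediate values such as \(a \cdot n + b\) fit in `long long`.

```cpp
#include <cassert>
#include <utility>

// Returns sum_{i=0}^{n-1} floor((a * i + b) / m)
// for n >= 0, m >= 1, a >= 0, b >= 0 (no overflow protection).
long long floor_sum_sketch(long long n, long long m, long long a, long long b) {
    long long ans = 0;
    while (true) {
        // Split off the multiples of m contained in a and in b.
        if (a >= m) {
            ans += n * (n - 1) / 2 * (a / m);
            a %= m;
        }
        if (b >= m) {
            ans += n * (b / m);
            b %= m;
        }
        // Now a < m and b < m; the largest numerator bound is a * n + b.
        const long long y_max = a * n + b;
        if (y_max < m) break;  // every remaining term is 0
        // Swap the roles of the two axes (count lattice points column-wise).
        n = y_max / m;
        b = y_max % m;
        std::swap(m, a);
    }
    return ans;
}
```

Each iteration replaces \((a, m)\) by \((m, a \bmod m)\), as in the Euclidean algorithm, which is where the logarithmic running time comes from.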
With floor sum, this problem can be solved in a total of \(\mathrm{O}(N \log M)\) time (where \(M = \max \lbrace A \rbrace\)), which is fast enough.
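Written out, the decomposition could look as follows. The notation is introduced here for illustration and is not from the original: \(A_{(1)} \leq \dots \leq A_{(N)}\) are the values in sorted order, and \(s(x)\) is the number of stones in the bag after \(x\) operations \(A\), assuming the bag starts empty.

\[ s(x) = \sum_{j=1}^{N} \min(A_j, x), \qquad s(x) = s(A_{(k)}) + (N-k)\,(x - A_{(k)}) \quad \text{for } A_{(k)} \lt x \leq A_{(k+1)} \]

\[ \text{answer} = (A_{(1)} + 1) + \sum_{k=1}^{N-1} \ \sum_{x = A_{(k)}+1}^{A_{(k+1)}} \left( \left\lfloor \frac{(N-k)\,x + s(A_{(k)}) - (N-k)\,A_{(k)}}{N} \right\rfloor + 1 \right) \]

Each inner sum is a floor sum with \(a = N-k\), \(b = s(A_{(k)}) - (N-k)\,A_{(k)}\) and \(m = N\), plus the constant \(A_{(k+1)} - A_{(k)}\).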
- Sample code (C++)
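#include <bits/stdc++.h>
#include "atcoder/math.hpp"
#include "atcoder/modint.hpp"
using namespace std;
using mint = atcoder::modint998244353;
using ll = long long;

int main() {
    ll N;
    cin >> N;
    vector<ll> A(N);
    for (auto& x : A) cin >> x;
    sort(begin(A), end(A));

    // s: stones in the bag after A[0] operations A (every dish gave A[0] stones).
    ll s = A[0] * N;
    // x = 0, 1, ..., A[0] each allow only y = 0.
    mint ans = A[0] + 1;
    for (int i = 1; i < N; i++) {
        // Handle x in (L, R]; here N - i dishes still lose a stone per operation A.
        ll L = A[i - 1];
        ll R = A[i];
        // For such x, the bag holds a * x + b stones.
        ll a = N - i;
        ll b = s - L * (N - i);
        ll m = N;
        // floor_sum(n, m, a, b) sums floor((a * k + b) / m) over 0 <= k < n.
        ll cur = 0;
        cur += atcoder::floor_sum(R + 1, m, a, b);
        cur -= atcoder::floor_sum(L + 1, m, a, b);
        cur += R - L;  // the "+ 1" (y = 0) for each of the R - L values of x
        ans += cur;
        s += (R - L) * (N - i);  // bag content at x = R
    }
    cout << ans.val() << "\n";
}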