E - You WILL Like Sigma Problem Editorial by en_translator
Observations
Let \(\lfloor i/j \rfloor\) denote the quotient when dividing \(i\) by \(j\).
For fixed \(j\) and \(k\), the integers \(i\) satisfying \(\lfloor i/j \rfloor = k\) form a range \(jk \leq i \leq \min(j(k+1) - 1,\ N)\). Since \((i \bmod j) = i - jk\) within this range,
\[ A_i \cdot B_j \cdot (i \bmod j) = B_j \cdot (A_i \cdot i) - B_j \cdot (A_i \cdot jk).\]
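As a sanity check of this decomposition, the following C++ sketch verifies on small values that every \(i\) lies in block \(k = \lfloor i/j \rfloor\), and that \((i \bmod j) = i - jk\) holds there. It also includes an \(O(NM)\) brute force, which is useful as a reference when testing the fast method. The helper names check_block_identity and brute_force are hypothetical. The target sum \(\sum_{i=1}^{N}\sum_{j=1}^{M} A_i B_j (i \bmod j)\) modulo \(998244353\), with non-negative \(A_i\) and \(B_j\), is an assumption inferred from the sample code below and is not quoted from the problem statement.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: for 1 <= j <= n and 0 <= i <= n, checks that i lies in block
// k = floor(i / j), i.e. j*k <= i <= j*(k+1) - 1, and that i mod j == i - j*k there.
bool check_block_identity(long long n) {
    for (long long j = 1; j <= n; j++) {
        for (long long i = 0; i <= n; i++) {
            long long k = i / j;  // integer division floors for non-negative i
            if (!(j * k <= i && i <= j * (k + 1) - 1)) return false;
            if (i % j != i - j * k) return false;
        }
    }
    return true;
}

// Hypothetical O(NM) reference, assuming the target is
// sum_{i=1}^{N} sum_{j=1}^{M} A_i * B_j * (i mod j) modulo 998244353.
// a and b are 1-indexed: a[0] and b[0] are dummies and are ignored.
// All entries are assumed to be non-negative.
long long brute_force(const std::vector<long long>& a, const std::vector<long long>& b) {
    const long long MOD = 998244353;
    const long long n = (long long)a.size() - 1;
    const long long m = (long long)b.size() - 1;
    long long ans = 0;
    for (long long i = 1; i <= n; i++) {
        for (long long j = 1; j <= m; j++) {
            long long term = a[i] % MOD * (b[j] % MOD) % MOD;  // A_i * B_j
            term = term * (i % j) % MOD;                       // times (i mod j)
            ans = (ans + term) % MOD;
        }
    }
    return ans;
}
```

For example, with \(A = (1, 2, 3)\) and \(B = (5, 7, 11)\), passed with a dummy leading 0, the brute force returns 83.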
How can we evaluate the sum over this range of \(i\)? The sum of the first term is \(B_j\) times the sum of \(A_i \cdot i\) over the range, and the sum of the second term is \(jkB_j\) times the sum of \(A_i\) over the range. After precomputing the cumulative sums of \(A_i\) and of \(A_i \cdot i\) in \(O(N)\) time, each of these range sums can be evaluated in \(O(1)\) time.
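The two cumulative sums can be sketched as follows. This minimal illustration assumes non-negative \(A_i\) stored 0-indexed with a dummy \(A_0 = 0\), and it takes everything modulo \(998244353\) as the sample code does. The struct PrefixSums and its method block_sum are hypothetical names. block_sum(l, r, offset) returns \(\sum_{l \leq i < r} A_i (i - \mathrm{offset})\). With \(l = jk\), \(r = \min(j(k+1), N+1)\) and offset \(= jk\), this is the inner sum above before multiplying by \(B_j\).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration of the two cumulative sums, modulo 998244353.
// Input a is 0-indexed with a[0] = 0, so a[i] holds A_i; entries assumed non-negative.
//   s[i] = a[0] + a[1] + ... + a[i-1]
//   t[i] = 0*a[0] + 1*a[1] + ... + (i-1)*a[i-1]
struct PrefixSums {
    static constexpr long long MOD = 998244353;
    std::vector<long long> s, t;

    explicit PrefixSums(const std::vector<long long>& a)
        : s(a.size() + 1, 0), t(a.size() + 1, 0) {
        for (std::size_t i = 0; i < a.size(); i++) {
            long long v = a[i] % MOD;
            s[i + 1] = (s[i] + v) % MOD;
            t[i + 1] = (t[i] + v * (long long)(i % MOD)) % MOD;
        }
    }

    // Returns sum_{l <= i < r} A_i * (i - offset) in O(1), computed as
    // (sum of A_i * i) - offset * (sum of A_i) over the half-open range [l, r).
    // offset is assumed to be non-negative.
    long long block_sum(std::size_t l, std::size_t r, long long offset) const {
        assert(l <= r && r < s.size());
        long long sum_ai_i = (t[r] - t[l] + MOD) % MOD;
        long long sum_ai = (s[r] - s[l] + MOD) % MOD;
        long long res = (sum_ai_i - offset % MOD * sum_ai) % MOD;
        return (res + MOD) % MOD;
    }
};
```

For instance, with \(A = (1, 2, 3)\), the block \(j = 2, k = 1\) covers \(i = 2, 3\), and block_sum(2, 4, 2) gives \(2 \cdot 0 + 3 \cdot 1 = 3\).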
For a fixed \(j\), \(k\) ranges from \(0\) to \(\lfloor N/j \rfloor\). Thus, over all \(1 \leq j \leq M\), the number of possible pairs \((j, k)\) is at most \(\left(\frac{N}{1} + \frac{N}{2} + \cdots + \frac{N}{N}\right) + M\), which is \(O(N \log N + M)\) by the bound on the harmonic series.
Together with the \(O(N)\) precomputation, this runs in a total of \(O(N \log N + M)\) time, which is fast enough.
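To see the harmonic bound numerically, this sketch counts the pairs \((j, k)\) with \(1 \leq j \leq M\) and \(0 \leq k \leq \lfloor N/j \rfloor\). This count equals the number of inner-loop iterations in the sample code. The sketch compares the count with \(N(1 + \ln N) + M\), using \(1 + \frac12 + \cdots + \frac1N \leq 1 + \ln N\). The function names count_pairs and harmonic_bound are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: number of (j, k) pairs with 1 <= j <= m and 0 <= k <= floor(n / j),
// i.e. the number of inner-loop iterations of the sample code for N = n, M = m.
long long count_pairs(long long n, long long m) {
    assert(n >= 0 && m >= 1);
    long long pairs = 0;
    for (long long j = 1; j <= m; j++) {
        pairs += n / j + 1;  // k = 0, 1, ..., floor(n / j)
    }
    return pairs;
}

// Bound from the text: (n/1 + n/2 + ... + n/n) + m <= n * (1 + ln n) + m,
// using 1 + 1/2 + ... + 1/n <= 1 + ln n for n >= 1.
double harmonic_bound(long long n, long long m) {
    assert(n >= 1);
    return (double)n * (1.0 + std::log((double)n)) + (double)m;
}
```

For \(N = M = 10\), it counts \(27 + 10 = 37\) pairs.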
Sample code (C++)
#include <iostream>
#include <vector>
#include <algorithm>
#include <atcoder/modint>
using std::cin;
using std::cout;
using std::vector;
using std::min;
typedef long long int ll;
using mint = atcoder::modint998244353;

ll n, m;
vector<mint> a, b;

void solve () {
    // line up the indices: after this, a[1..N] holds A_1..A_N and b[1..M] holds B_1..B_M
    a.insert(a.begin(), 0);
    n++;
    b.insert(b.begin(), 0);
    m++;
    // sum[i]  := sum_{k < i} a[k] * 1
    // sum2[i] := sum_{k < i} a[k] * k
    vector<mint> sum(n+1, 0), sum2(n+1, 0);
    for (ll i = 0; i < n; i++) {
        sum[i+1]  = sum[i]  + a[i] * 1;
        sum2[i+1] = sum2[i] + a[i] * i;
    }
    mint ans = 0;
    // bi plays the role of j in the text, the loop variable i plays the role of k,
    // and the summation index k in the comment below plays the role of i
    for (ll bi = 1; bi < m; bi++) {
        mint bans = 0;
        for (ll i = 0; i * bi < n; i++) {
            ll l = (i+0) * bi;
            ll r = min((i+1) * bi, n);
            // sum_{l <= k < r} a[k] * (k - i*bi)
            bans += (sum2[r] - sum2[l]);
            bans -= (sum[r] - sum[l]) * (i*bi);
        }
        ans += bans * b[bi];
    }
    cout << ans.val() << "\n";
}

int main (void) {
    std::cin.tie(nullptr);
    std::ios_base::sync_with_stdio(false);
    // read N and M, then A_1..A_N and B_1..B_M
    cin >> n >> m;
    a.resize(n);
    for (ll i = 0; i < n; i++) {
        ll x; cin >> x;
        a[i] = x;
    }
    b.resize(m);
    for (ll i = 0; i < m; i++) {
        ll x; cin >> x;
        b[i] = x;
    }
    solve();
    return 0;
}
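If the AtCoder Library is unavailable, the same algorithm can be written with plain 64-bit modular arithmetic. The following minimal sketch works under the same assumptions as the sample code above: non-negative inputs, with the answer taken modulo \(998244353\). solve_plain is a hypothetical name. It takes \(A\) and \(B\) as 0-indexed vectors, exactly as main reads them, and returns the answer instead of printing it. Its loop variables are named \(j\) and \(k\) as in the text.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical ACL-free variant of solve(). a_in[i] holds A_{i+1}, b_in[j] holds B_{j+1}.
// Returns sum_{i=1}^{N} sum_{j=1}^{M} A_i * B_j * (i mod j) modulo 998244353.
long long solve_plain(const std::vector<long long>& a_in, const std::vector<long long>& b_in) {
    const long long MOD = 998244353;
    const long long n = (long long)a_in.size() + 1;  // a[0] = 0 padding, as in solve()
    std::vector<long long> a(n, 0);
    for (long long i = 1; i < n; i++) {
        assert(a_in[i - 1] >= 0);  // non-negative inputs assumed
        a[i] = a_in[i - 1] % MOD;
    }

    // sum[i] = sum_{t < i} a[t], sum2[i] = sum_{t < i} a[t] * t (mod MOD)
    std::vector<long long> sum(n + 1, 0), sum2(n + 1, 0);
    for (long long i = 0; i < n; i++) {
        sum[i + 1] = (sum[i] + a[i]) % MOD;
        sum2[i + 1] = (sum2[i] + a[i] * (i % MOD)) % MOD;
    }

    long long ans = 0;
    for (long long j = 1; j <= (long long)b_in.size(); j++) {
        assert(b_in[j - 1] >= 0);
        long long bans = 0;
        // Block k covers l = k*j <= i < r = min((k+1)*j, n), where i mod j = i - k*j.
        for (long long k = 0; k * j < n; k++) {
            long long l = k * j;
            long long r = std::min((k + 1) * j, n);
            long long s2 = (sum2[r] - sum2[l] + MOD) % MOD;  // sum of A_i * i
            long long s1 = (sum[r] - sum[l] + MOD) % MOD;    // sum of A_i
            bans = (bans + s2 - (l % MOD) * s1 % MOD + MOD) % MOD;
        }
        ans = (ans + bans * (b_in[j - 1] % MOD)) % MOD;
    }
    return ans;
}
```

With \(A = (1, 2, 3)\) and \(B = (1, 1, 1)\), it returns \(9\). For \(A = (1, 2)\) and \(B = (0, 0, 0, 1)\), only \(j = 4 > N\) contributes, and its single block \(k = 0\) gives \(1 + 4 = 5\).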