Official

E - You WILL Like Sigma Problem Editorial by en_translator


Observations

Let \(\lfloor i/j \rfloor\) denote the quotient when dividing \(i\) by \(j\).

For fixed \(j\) and \(k\), the integers \(i\) (with \(1 \leq i \leq N\)) satisfying \(\lfloor i/j \rfloor = k\) form the range \(\max(1,\ jk) \leq i \leq \min(j(k+1) - 1,\ N)\). Since \((i \bmod j) = i - jk\) within this range,

\[ A_i \cdot B_j \cdot (i \bmod j) = B_j \cdot (A_i \cdot i) - B_j \cdot (A_i \cdot jk).\]

How can we evaluate the sum over this range of \(i\)? The first term sums to \(B_j\) times the sum of \(A_i \cdot i\) over the range, and the second to \(jkB_j\) times the sum of \(A_i\). Both range sums can be answered in \(O(1)\) time per query after precomputing the cumulative sums of \(A_i\) and \(A_i \cdot i\) in \(O(N)\) time.
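The range-sum step can be sketched as follows. This is a minimal illustration, not the editorial's own code: the names `PrefixSums`, `blockSum`, and `kPrefixMod` are hypothetical, the array is assumed to be 1-indexed with a dummy entry at index 0, and plain 64-bit arithmetic modulo 998244353 stands in for a modint type.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumption: answers are taken modulo 998244353, as in the sample code.
constexpr std::int64_t kPrefixMod = 998244353;

// Hypothetical helper: cumulative sums over a 1-indexed array a[1..n]
// (a[0] is an unused dummy entry).
struct PrefixSums {
    std::vector<std::int64_t> s1;  // s1[i] = a[1] + ... + a[i]
    std::vector<std::int64_t> s2;  // s2[i] = a[1]*1 + ... + a[i]*i

    explicit PrefixSums(const std::vector<std::int64_t>& a)
        : s1(a.size(), 0), s2(a.size(), 0) {
        for (std::size_t i = 1; i < a.size(); ++i) {
            const std::int64_t ai = a[i] % kPrefixMod;
            const std::int64_t idx = static_cast<std::int64_t>(i) % kPrefixMod;
            s1[i] = (s1[i - 1] + ai) % kPrefixMod;
            s2[i] = (s2[i - 1] + ai * idx) % kPrefixMod;
        }
    }

    // Sum of a[i] * (i - base) over l <= i <= r, with 1 <= l.
    // With base = j*k and [l, r] inside the block for (j, k),
    // this equals the sum of a[i] * (i mod j) over that block.
    std::int64_t blockSum(std::int64_t l, std::int64_t r, std::int64_t base) const {
        if (l > r) return 0;
        const std::int64_t t1 = (s2[r] - s2[l - 1] + kPrefixMod) % kPrefixMod;
        const std::int64_t t0 = (s1[r] - s1[l - 1] + kPrefixMod) % kPrefixMod;
        const std::int64_t sub = t0 * (base % kPrefixMod) % kPrefixMod;
        return (t1 - sub + kPrefixMod) % kPrefixMod;
    }
};
```

For example, with \(A = (3, 1, 4, 1, 5)\), \(j = 2\) and \(k = 1\), the block is \(2 \leq i \leq 3\) and `blockSum(2, 3, 2)` evaluates \(A_2 \cdot 0 + A_3 \cdot 1\). Multiplying the result by \(B_j\) gives that block's contribution.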

For a fixed \(j\), \(k\) ranges from \(0\) to \(\lfloor N/j \rfloor\) (the case \(k = 0\) covers \(i < j\)). Thus, with \(j\) ranging over \(1, \ldots, M\), the number of possible pairs \((j, k)\) is at most \(M + \frac{N}{1} + \frac{N}{2} + \cdots + \frac{N}{M}\), which is \(O(M + N \log M)\) (harmonic series).

Including the \(O(N)\) precomputation of the cumulative sums, this runs in a total of \(O(N + M + N \log M)\) time, which is fast enough.
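To make the harmonic-series bound concrete, here is a small sketch that counts the \((j, k)\) blocks directly and compares the count with \(M + N(1/1 + \cdots + 1/M)\). The function names `countBlocks` and `harmonicBound` are hypothetical, and the count assumes \(j\) runs over \(1, \ldots, M\) with the \(k = 0\) block included for every \(j\).

```cpp
#include <cstdint>

// Hypothetical helper: number of (j, k) pairs visited when, for each
// j in 1..m, k runs over 0..floor(n / j).
std::int64_t countBlocks(std::int64_t n, std::int64_t m) {
    std::int64_t count = 0;
    for (std::int64_t j = 1; j <= m; ++j) {
        count += n / j + 1;  // floor(n / j) + 1 values of k
    }
    return count;
}

// Hypothetical helper: the upper bound m + n * (1/1 + 1/2 + ... + 1/m).
double harmonicBound(std::int64_t n, std::int64_t m) {
    double h = 0.0;
    for (std::int64_t j = 1; j <= m; ++j) {
        h += 1.0 / static_cast<double>(j);
    }
    return static_cast<double>(m) + static_cast<double>(n) * h;
}
```

For \(N = M = 10\), the exact count is \(37\), while the bound evaluates to roughly \(39.3\).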

Sample code (C++)

#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
using std::cin;
using std::cout;
using std::cerr;
using std::endl;
using std::vector;
using std::pair;
using std::make_pair;
using std::min;

typedef long long int ll;

#include <atcoder/modint>
using mint = atcoder::modint998244353;

ll n, m;
vector<mint> a, b;

void solve () {
	// shift both arrays to 1-based indices by prepending a dummy 0
	a.insert(a.begin(), 0);
	n++;
	b.insert(b.begin(), 0);
	m++;

	// sum[i]  := sum_{k < i} a[k] * 1
	// sum2[i] := sum_{k < i} a[k] * k
	vector<mint> sum(n+1, 0), sum2(n+1, 0);
	for (ll i = 0; i < n; i++) {
		sum[ i+1] = sum[ i] + a[i] * 1;
		sum2[i+1] = sum2[i] + a[i] * i;
	}

	mint ans = 0;
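	for (ll bi = 1; bi < m; bi++) {
		// bans accumulates sum_k a[k] * (k mod bi) for this divisor bi
		mint bans = 0;
		// i is the quotient; the block [i*bi, (i+1)*bi) is clipped to [0, n)
		for (ll i = 0; i * bi < n; i++) {
			ll l = (i+0) * bi;
			ll r = min((i+1) * bi, n);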

			// sum_{l <= k < r} a[k] * (k - i*bi)
			bans += (sum2[r] - sum2[l]);
			bans -= (sum[r] - sum[l]) * (i*bi);
		}
		ans += bans * b[bi];
	}

	cout << ans.val() << "\n";
}

int main (void) {
	std::cin.tie(nullptr);
	std::ios_base::sync_with_stdio(false);

	cin >> n >> m;
	
	a.resize(n);
	for (ll i = 0; i < n; i++) {
		ll x; cin >> x;
		a[i] = x;
	}
	b.resize(m);
	for (ll i = 0; i < m; i++) {
		ll x; cin >> x;
		b[i] = x;
	}

	solve();

	return 0;
}
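
When testing a solution like the one above, a direct double loop is a convenient cross-check on small inputs. The sketch below is not part of the editorial: it assumes the target quantity is \(\sum_{i=1}^{N} \sum_{j=1}^{M} A_i \cdot B_j \cdot (i \bmod j)\) modulo 998244353, which is what the sample code computes, and the names `sigmaBrute`, `sigmaBlocks`, and `kMod` are hypothetical. It uses plain 64-bit arithmetic instead of `atcoder::modint`, so it compiles without the AtCoder Library.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::int64_t kMod = 998244353;

// O(N * M) reference. A[i-1] holds A_i and B[j-1] holds B_j.
std::int64_t sigmaBrute(const std::vector<std::int64_t>& A,
                        const std::vector<std::int64_t>& B) {
    std::int64_t total = 0;
    for (std::size_t i = 1; i <= A.size(); ++i) {
        for (std::size_t j = 1; j <= B.size(); ++j) {
            std::int64_t term = (A[i - 1] % kMod) * (B[j - 1] % kMod) % kMod;
            term = term * (static_cast<std::int64_t>(i % j) % kMod) % kMod;
            total = (total + term) % kMod;
        }
    }
    return total;
}

// Block decomposition with cumulative sums, mirroring the approach above.
std::int64_t sigmaBlocks(const std::vector<std::int64_t>& A,
                         const std::vector<std::int64_t>& B) {
    const std::int64_t n = static_cast<std::int64_t>(A.size());
    const std::int64_t m = static_cast<std::int64_t>(B.size());
    // s1[i] = A_1 + ... + A_i, s2[i] = A_1*1 + ... + A_i*i (both mod kMod)
    std::vector<std::int64_t> s1(n + 1, 0), s2(n + 1, 0);
    for (std::int64_t i = 1; i <= n; ++i) {
        const std::int64_t ai = A[i - 1] % kMod;
        s1[i] = (s1[i - 1] + ai) % kMod;
        s2[i] = (s2[i - 1] + ai * (i % kMod)) % kMod;
    }
    std::int64_t total = 0;
    for (std::int64_t j = 1; j <= m; ++j) {
        std::int64_t inner = 0;
        // base = j * k; the block covers max(1, base) <= i <= min(base + j - 1, n)
        for (std::int64_t base = 0; base <= n; base += j) {
            const std::int64_t l = std::max<std::int64_t>(1, base);
            const std::int64_t r = std::min(base + j - 1, n);
            const std::int64_t t1 = (s2[r] - s2[l - 1] + kMod) % kMod;
            const std::int64_t t0 = (s1[r] - s1[l - 1] + kMod) % kMod;
            const std::int64_t sub = t0 * (base % kMod) % kMod;
            inner = (inner + t1 - sub + kMod) % kMod;
        }
        total = (total + inner * (B[j - 1] % kMod)) % kMod;
    }
    return total;
}
```

With \(A = (3, 1, 4, 1, 5)\) and \(B = (2, 7, 1)\), both functions return \(100\): \(j = 1\) contributes \(0\), \(j = 2\) contributes \(7 \cdot 12 = 84\), and \(j = 3\) contributes \(1 \cdot 16 = 16\).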
