E - Count A%B=C Editorial by en_translator
Rephrase the conditions
Because of the condition “\(a \bmod b = c\),” every conforming triplet has the form \((a,b,a \bmod b)\) and is therefore determined by the pair \((a,b)\). Accordingly, we rephrase the conditions on \((a,b,c)\) as conditions on \((a,b)\) and count the pairs satisfying the new criteria.
If \(a < b\), then \(c = a \bmod b = a\), so the condition “\(a,b\), and \(c\) are pairwise distinct” is never satisfied. Conversely, if \(a > b\), then \(c = a \bmod b < b < a\), so the condition “\(a,b\), and \(c\) are pairwise distinct” is always satisfied. (If \(a = b\), then \(c = 0\), which the next condition rules out.)
Also, \(c = a \bmod b\) is positive if and only if \(a\) is not a multiple of \(b\).
Therefore, it is sufficient to count the number of integer pairs \((a,b)\) with:
- \(1 \leq b < a \leq N\)
- \(a\) is not a multiple of \(b\).
Counting
For a fixed \(b \in \{1, \dots, N\}\), let \(f(N, b)\) be the number of integers \(a\) satisfying the conditions.
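There are \((N-b)\) integers from \(b+1\) to \(N\). Among \(1, \dots, N\) there are \(\lfloor \frac{N}{b} \rfloor\) multiples of \(b\); excluding \(b\) itself leaves \(\left( \lfloor \frac{N}{b} \rfloor - 1 \right)\) multiples, all of which lie between \(b+1\) and \(N\). Therefore, \(f(N,b) = (N - b) - (\lfloor \frac{N}{b} \rfloor - 1)\).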
Hence, the sought value is the sum of \(f(N,b)\) over \(b = 1, \dots, N\), namely
\[\sum_{b = 1}^{N} \left(N - b + 1 - \left\lfloor \frac{N}{b} \right\rfloor\right ).\]
Optimization
Since \(N\) can be as large as \(10^{12}\), evaluating the expression above term by term takes \(O(N)\) time and is unlikely to finish within the time limit.
Let us split the expression into the following two terms:
\[\sum_{b = 1}^{N} \left(N - b + 1 \right) - \sum_{b=1}^N \left\lfloor \frac{N}{b} \right\rfloor.\]
The former term
The first term can be explicitly written as \(N + (N-1) + \dots + 2 + 1\). This is the sum of an arithmetic progression, which equals \(\displaystyle\frac{N(N+1)}{2}\).
The latter term
Now we consider the latter term.
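Consider the values of \(\displaystyle\left\lfloor \frac{N}{b} \right\rfloor\). There are at most \(\sqrt{N}\) integers \(b\) with \(b \leq \sqrt{N}\), so they yield at most \(\sqrt{N}\) distinct values. If \(b > \sqrt{N}\), then \(\displaystyle\left\lfloor \frac{N}{b} \right\rfloor < \sqrt{N}\), so these \(b\) also yield at most \(\sqrt{N}\) distinct values (all of them positive integers less than \(\sqrt{N}\)). Hence \(\displaystyle\left\lfloor \frac{N}{b} \right\rfloor\) takes at most \(2 \sqrt{N}\) distinct values.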
So we evaluate the sum in the following two steps:
- For each \(k = 1, \cdots, \lfloor \sqrt{N} \rfloor\), find the range of \(b\) such that \(\displaystyle\left\lfloor \frac{N}{b} \right\rfloor = k\), and add \(k\) times the number of such \(b\).
- For the remaining \(b\) not covered by the first step, evaluate \(\displaystyle\left\lfloor \frac{N}{b} \right\rfloor\) individually and add them up.
The values of \(b\) not covered by the first step satisfy \(\displaystyle\left\lfloor \frac{N}{b} \right\rfloor \geq \displaystyle\left\lfloor \sqrt{N} \right\rfloor + 1\). Since
\[ \frac{N}{b} \geq \displaystyle\left\lfloor \frac{N}{b} \right\rfloor \geq \displaystyle\left\lfloor \sqrt{N} \right\rfloor + 1 > \sqrt{N}, \]
we have \(b < \sqrt{N}\). Therefore, at most \(\lfloor \sqrt{N} \rfloor\) values of \(b\) are handled in the second step, and both steps run in \(O(\sqrt{N})\) time.
Since \(N\) is large, beware of overflow throughout the implementation; for example, \(\frac{N(N+1)}{2}\) is about \(5 \times 10^{23}\) when \(N = 10^{12}\), which does not fit in a 64-bit integer. Storing such values exactly as big integers makes the computation heavy, so reduce intermediate values modulo the required modulus as you go (the sample code below does this with atcoder::modint998244353).
The range of \(b\) satisfying \(\displaystyle\left\lfloor \frac{N}{b} \right\rfloor = k\)
The range of integers \(b\) satisfying \(\displaystyle\left\lfloor \frac{N}{b} \right\rfloor = k\) can be obtained by the following equivalent transformations:
\[ \begin{aligned} \displaystyle\left\lfloor \frac{N}{b} \right\rfloor = k &\Longleftrightarrow k \leq \frac{N}{b} < k+1 \\ &\Longleftrightarrow \frac{N}{k+1} < b \leq \frac{N}{k} \\ &\Longleftrightarrow \left\lfloor\frac{N}{k+1}\right\rfloor + 1 \leq b \leq \left\lfloor\frac{N}{k}\right\rfloor. \end{aligned} \]
Sample code (C++)
```cpp
#include <iostream>
#include <cstdio>
#include <string>
#include <vector>
#include <array>
#include <algorithm>
using std::cin;
using std::cout;
using std::cerr;
using std::endl;
using std::string;
using std::to_string;
using std::pair;
using std::make_pair;
using std::vector;
using std::min;
using std::max;
using std::array;
#include <atcoder/all>
using mint = atcoder::modint998244353;
typedef long long int ll;
ll n;

void solve () {
    mint fsum = 0; // will be the sum of floor(n/b) (1 <= b <= n)
    ll prevk = -1;
    ll idx = 1;
    while (true) {
        // Here k plays the role of b in the editorial:
        // floor(n/k) == idx
        // <=> idx <= n/k < idx+1
        // <=> n/(idx+1) < k <= n/idx
        // <=> floor(n/(idx+1)) + 1 <= k <= floor(n/idx)
        ll l = n / (idx+1) + 1;
        ll r = n / idx;
        if (l <= r) fsum += (mint)(r-l+1) * (mint)idx;
        prevk = l;
        // stop once l * l <= n (written as l <= n / l to avoid overflow)
        if (prevk <= n / prevk) break;
        idx++;
    }
    // add floor(n/k) individually for the remaining k = prevk-1, ..., 1
    while (--prevk >= 1) {
        fsum += (mint)(n / prevk);
    }
    mint ans = 0; // the sum of (n-b+1) - floor(n/b)
    ans += (mint)n * (mint)(n+1) / (mint)2;
    ans -= fsum;
    cout << ans.val() << "\n";
    return;
}

int main (void) {
    std::ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    solve();
    return 0;
}
```