G - Sum of Prod of Mod of Linear Editorial by en_translator
Knowledge of the generalized floor sum is required. The generalized floor sum is an algorithm that, given integers \(p,q,N,M,A,B\), computes \(\displaystyle f_{p,q}(N,M,A,B)=\sum_{i=0}^{N-1}i^p\left\lfloor \frac{Ai+B}M \right\rfloor^q\) in \(O((p+q)^4\log M)\) time. For more details on the generalized floor sum, please refer to this article (in Japanese).
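As a reference for testing, the definition of \(f_{p,q}\) can be evaluated directly. This is a minimal brute-force sketch running in \(O(N)\) time, not the \(O((p+q)^4\log M)\) algorithm; the function name `gen_floor_sum_naive` is hypothetical.

```python
def gen_floor_sum_naive(p: int, q: int, n: int, m: int, a: int, b: int) -> int:
    """Brute-force sum_{i=0}^{n-1} i^p * floor((a*i + b) / m)^q (reference only)."""
    # Python's // is floor division, so this matches the definition even if a*i + b < 0.
    # Note that 0**0 == 1 in Python, as required for the i = 0 term when p = 0.
    return sum(i**p * ((a * i + b) // m) ** q for i in range(n))


# floor((2i + 1) / 3) for i = 0, 1, 2, 3 is 0, 1, 1, 2
print(gen_floor_sum_naive(0, 1, 4, 3, 2, 1))  # 4
print(gen_floor_sum_naive(1, 1, 4, 3, 2, 1))  # 0*0 + 1*1 + 2*1 + 3*2 = 9
print(gen_floor_sum_naive(0, 2, 4, 3, 2, 1))  # 0 + 1 + 1 + 4 = 6
```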
If \(A=0\), then the answer is obviously \(NB_1B_2\), so we assume \(A>0\).
By symmetry, we assume \(B_1\le B_2\).
First, we transform the expression to be computed.
\[ \begin{aligned} &\phantom{=} \sum_{k=0}^{N-1}\left\lbrace (Ak+B_1)\ \text{mod}\ M \right\rbrace\left\lbrace (Ak+B_2)\ \text{mod}\ M \right\rbrace\\ &=\sum_{k=0}^{N-1}\left((Ak+B_1)-M\left\lfloor \frac{Ak+B_1}M \right\rfloor \right)\left((Ak+B_2)-M\left\lfloor \frac{Ak+B_2}M \right\rfloor \right)\\ &=\sum_{k=0}^{N-1}(Ak+B_1)(Ak+B_2)-M\sum_{k=0}^{N-1}\left((Ak+B_1)\left\lfloor\frac{Ak+B_2}M \right\rfloor+(Ak+B_2)\left\lfloor\frac{Ak+B_1}M \right\rfloor \right)+M^2\sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_1}M \right\rfloor\left\lfloor\frac{Ak+B_2}M \right\rfloor \\ \end{aligned} \]
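This expansion can be sanity-checked numerically. In the sketch below, the helper names `lhs` and `rhs` are hypothetical and every sum is computed by brute force in \(O(N)\), so it is only a check, not part of the solution. The last instance also confirms the answer \(NB_1B_2\) for \(A=0\).

```python
def lhs(n, m, a, b1, b2):
    """Direct evaluation of sum_{k<n} ((a*k + b1) mod m) * ((a*k + b2) mod m)."""
    return sum(((a * k + b1) % m) * ((a * k + b2) % m) for k in range(n))


def rhs(n, m, a, b1, b2):
    """The three-term expansion obtained from x mod m = x - m * floor(x / m)."""
    poly = sum((a * k + b1) * (a * k + b2) for k in range(n))
    cross = sum((a * k + b1) * ((a * k + b2) // m) + (a * k + b2) * ((a * k + b1) // m)
                for k in range(n))
    square = sum(((a * k + b1) // m) * ((a * k + b2) // m) for k in range(n))
    return poly - m * cross + m * m * square


# Small instances, including one with A = 0 (expected N * B1 * B2 = 6 * 3 * 4)
for args in [(10, 7, 3, 2, 5), (13, 9, 5, 4, 1), (6, 5, 0, 3, 4)]:
    assert lhs(*args) == rhs(*args)
print(lhs(6, 5, 0, 3, 4))  # 72
```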
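\(\displaystyle \sum_{k=0}^{N-1}(Ak+B_1)(Ak+B_2)-M\sum_{k=0}^{N-1}\left((Ak+B_1)\left\lfloor\frac{Ak+B_2}M \right\rfloor+(Ak+B_2)\left\lfloor\frac{Ak+B_1}M \right\rfloor \right)\) can be evaluated in \(O(\log M)\) time: the first sum is a polynomial in \(N\), and the second one is a combination of \(f_{0,1}\) and \(f_{1,1}\), because \(\displaystyle \sum_{k=0}^{N-1}(Ak+B_1)\left\lfloor\frac{Ak+B_2}M \right\rfloor=A\,f_{1,1}(N,M,A,B_2)+B_1\,f_{0,1}(N,M,A,B_2)\) (and similarly with \(B_1\) and \(B_2\) swapped). Therefore, it suffices to evaluate \(\displaystyle \sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_1}M \right\rfloor\left\lfloor\frac{Ak+B_2}M \right\rfloor\).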
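The following sketch illustrates how the first two sums reduce to closed forms and floor-sum quantities. It assumes a hypothetical callback `floor_sums(n, m, a, b)` returning the pair \((f_{0,1}, f_{1,1})\); an \(O(N)\) brute-force stand-in is plugged in here, whereas an actual solution would use the generalized floor sum.

```python
def poly_part(n, a, b1, b2):
    """Closed form of sum_{k=0}^{n-1} (a*k + b1) * (a*k + b2)."""
    s1 = n * (n - 1) // 2                # sum of k
    s2 = (n - 1) * n * (2 * n - 1) // 6  # sum of k^2
    return a * a * s2 + a * (b1 + b2) * s1 + b1 * b2 * n


def cross_part(n, m, a, b1, b2, floor_sums):
    """sum_k ((a*k+b1) * floor((a*k+b2)/m) + (a*k+b2) * floor((a*k+b1)/m)).

    floor_sums(n, m, a, b) must return (f_{0,1}, f_{1,1}), i.e.
    (sum_k floor((a*k+b)/m), sum_k k * floor((a*k+b)/m)).
    """
    f1, h1 = floor_sums(n, m, a, b1)
    f2, h2 = floor_sums(n, m, a, b2)
    return (a * h2 + b1 * f2) + (a * h1 + b2 * f1)


def naive_floor_sums(n, m, a, b):
    """O(n) stand-in for the generalized floor sum (for checking only)."""
    return (sum((a * k + b) // m for k in range(n)),
            sum(k * ((a * k + b) // m) for k in range(n)))


n, m, a, b1, b2 = 12, 7, 3, 2, 5
assert poly_part(n, a, b1, b2) == sum((a * k + b1) * (a * k + b2) for k in range(n))
assert cross_part(n, m, a, b1, b2, naive_floor_sums) == sum(
    (a * k + b1) * ((a * k + b2) // m) + (a * k + b2) * ((a * k + b1) // m)
    for k in range(n))
```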
Here, let us consider the difference between \(\displaystyle \sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_1}M \right\rfloor\left\lfloor\frac{Ak+B_2}M \right\rfloor\) and \(\displaystyle \sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_2}M \right\rfloor^2\). The latter equals \(f_{0,2}(N,M,A,B_2)\) and can be computed with the generalized floor sum, so finding the difference is enough to obtain the answer. First, we have
\[\sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_1}M \right\rfloor\left\lfloor\frac{Ak+B_2}M \right\rfloor-\sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_2}M \right\rfloor^2=-\sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_2}M \right\rfloor\left( \left\lfloor\frac{Ak+B_2}M \right\rfloor -\left\lfloor\frac{Ak+B_1}M \right\rfloor \right).\]
Here, since \(0\le B_1 \le B_2 < M\), we have \(Ak+B_1\le Ak+B_2<(Ak+B_1)+M\), so the value \(\displaystyle \left\lfloor\frac{Ak+B_2}M \right\rfloor -\left\lfloor\frac{Ak+B_1}M \right\rfloor\) is always either \(0\) or \(1\).
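This claim can be checked exhaustively for small parameters. The brute-force sketch below (hypothetical helper `floor_gaps`) collects every value the difference takes over \(0\le k<N\) and confirms that it lies in \(\{0,1\}\).

```python
def floor_gaps(n, m, a, b1, b2):
    """Set of values floor((a*k+b2)/m) - floor((a*k+b1)/m) over k in [0, n)."""
    return {(a * k + b2) // m - (a * k + b1) // m for k in range(n)}


# Exhaustive check over small m, a and all 0 <= b1 <= b2 < m
for m in range(1, 8):
    for a in range(m):
        for b1 in range(m):
            for b2 in range(b1, m):
                assert floor_gaps(20, m, a, b1, b2) <= {0, 1}

print(floor_gaps(20, 7, 3, 2, 6))  # {0, 1}
```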
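For a fixed \(c\), the values of \(k\) (with \(0\le k<N\)) satisfying \(\displaystyle c=\left\lfloor\frac{Ak+B_2}M \right\rfloor\) are exactly those with \(d_c^2\le k < d^2_{c+1}\), where the superscript \(i\in\{1,2\}\) of \(d_c^i\) is an index, not an exponent: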
\[ d_c^i = \min\left(N,\left\lceil\frac{cM-B_i}A \right\rceil\right) \]
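Within this range, the values of \(k\) satisfying \(\displaystyle c-1=\left\lfloor\frac{Ak+B_1}M \right\rfloor\), that is, those for which the difference above equals \(1\), are exactly those with \( d^2_{c}\le k < d^1_{c}\). Writing \(\displaystyle X=\left\lfloor\frac{A(N-1)+B_2}M \right\rfloor\) for the largest value of \(\displaystyle \left\lfloor\frac{Ak+B_2}M \right\rfloor\), the sought difference is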
\[ \begin{aligned} &\phantom{=}- \sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_2}M \right\rfloor\left( \left\lfloor\frac{Ak+B_2}M \right\rfloor -\left\lfloor\frac{Ak+B_1}M \right\rfloor \right)\\ &=-\sum_{c=0}^Xc\sum_{k=d_{c}^2}^{d_{c}^1-1}1\\ &=-\sum_{c=0}^X c(d_{c}^1-d_{c}^2)\\ &=-\sum_{c=0}^{X-1} c(d_{c}^1-d_{c}^2)-X(d_{X}^1-d_{X}^2).\\ \end{aligned} \]
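The ranges and the resulting formula for the difference can be cross-checked against brute force. In this sketch, `d(c, b, n, m, a)` computes \(d_c^i\) with \(B_i\) passed as `b`; all function names are hypothetical, every sum is evaluated naively, and the assumptions are \(A>0\) and \(0\le B_1\le B_2<M\). Only \(c\ge 1\) is checked for the ranges, since the \(c=0\) term is multiplied by \(0\).

```python
def ceil_div(x, a):
    """ceil(x / a) for a > 0; Python's // rounds toward negative infinity."""
    return -((-x) // a)


def d(c, b, n, m, a):
    """d_c^i = min(N, ceil((c*M - B_i) / A)); pass B_i as b."""
    return min(n, ceil_div(c * m - b, a))


def difference_by_ranges(n, m, a, b1, b2):
    """-sum_{c=0}^{X} c * (d_c^1 - d_c^2)."""
    x = (a * (n - 1) + b2) // m
    return -sum(c * (d(c, b1, n, m, a) - d(c, b2, n, m, a)) for c in range(x + 1))


def difference_naive(n, m, a, b1, b2):
    """-sum_k floor((a*k+b2)/m) * (floor((a*k+b2)/m) - floor((a*k+b1)/m))."""
    total = 0
    for k in range(n):
        q1, q2 = (a * k + b1) // m, (a * k + b2) // m
        total -= q2 * (q2 - q1)
    return total


def check_ranges(n, m, a, b1, b2):
    """For c >= 1, the k with floor2 == c and floor1 == c - 1 are d_c^2 <= k < d_c^1."""
    x = (a * (n - 1) + b2) // m
    for c in range(1, x + 1):
        expected = [k for k in range(n)
                    if (a * k + b2) // m == c and (a * k + b1) // m == c - 1]
        assert list(range(d(c, b2, n, m, a), d(c, b1, n, m, a))) == expected


for args in [(4, 3, 2, 0, 2), (20, 7, 3, 2, 5), (15, 10, 7, 0, 9), (9, 13, 11, 1, 3)]:
    check_ranges(*args)
    assert difference_by_ranges(*args) == difference_naive(*args)
print(difference_by_ranges(4, 3, 2, 0, 2))  # -3
```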
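For \(c<X\), we have \(cM\le XM-M\le A(N-1)+B_2-M<A(N-1)\), so \(\displaystyle \frac{cM-B_i}A<N-1\) and the \(\min\) has no effect: \(\displaystyle d_c^i=\left\lceil\frac{cM-B_i}A \right\rceil=\left\lfloor\frac{Mc+A-B_i-1}A \right\rfloor\). Hence \(\displaystyle \sum_{c=0}^{X-1}c\,d_c^i=f_{1,1}(X,A,M,A-B_i-1)\), which can also be found with the generalized floor sum, while the remaining term \(X(d_X^1-d_X^2)\) is computed directly.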
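The rewriting of \(\sum_{c=0}^{X-1}c\,d_c^i\) as a floor sum can be verified numerically with the brute-force sketch below (function names hypothetical). In the sample implementation, this quantity is the third value returned by `floor_sum(x, a, m, a - b1 - 1)` (respectively `a - b2 - 1`). The last instance has \(B_i\ge A\), so the constant \(A-B_i-1\) is negative; a floor-sum routine must accept that, which Python's `//` and `%` (rounding toward negative infinity) make straightforward.

```python
def sum_c_d_direct(x, n, m, a, b):
    """sum_{c=0}^{x-1} c * min(n, ceil((c*m - b) / a)), straight from the definition of d_c^i."""
    return sum(c * min(n, -((b - c * m) // a)) for c in range(x))


def sum_c_d_as_floor_sum(x, m, a, b):
    """The same sum as f_{1,1}(X, A, M, A - B - 1) = sum_c c * floor((m*c + a - b - 1) / a)."""
    return sum(c * ((m * c + a - b - 1) // a) for c in range(x))


for n, m, a, b1, b2 in [(20, 7, 3, 2, 5), (15, 10, 7, 0, 9),
                        (9, 13, 11, 1, 3), (200, 50, 1, 40, 49)]:
    x = (a * (n - 1) + b2) // m
    for b in (b1, b2):
        assert sum_c_d_direct(x, n, m, a, b) == sum_c_d_as_floor_sum(x, m, a, b)
```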
Implementing this carefully solves the problem in \(O(\log M)\) time per query. A sample implementation in Python is shown below.
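```python
import sys

input = sys.stdin.readline


def floor_sum(n, m, a, b):
    # Returns (f, g, h) where, summing over i = 0, ..., n-1,
    #   f = sum floor((a*i + b) / m)        (f_{0,1})
    #   g = sum floor((a*i + b) / m)^2      (f_{0,2})
    #   h = sum i * floor((a*i + b) / m)    (f_{1,1})
    if m == 0:
        # Reached only when the caller's a % m was 0; then y == 0 and all sums vanish
        return 0, 0, 0
    # Split off integer parts: a = a1*m + a2, b = b1*m + b2 (also valid for negative b)
    a1, a2 = a // m, a % m
    b1, b2 = b // m, b % m
    # Count by floor values instead of by i; this is a floor sum with (m, a2) swapped
    y = (a2 * n + b2) // m
    ff, gg, hh = floor_sum(y, a2, m, m + a2 - b2 - 1)
    nn = n * (n - 1) // 2
    f = n * y - ff
    g = n * y * y - ff - 2 * hh
    h = nn * y + (ff - gg) // 2
    # Add back the contribution of the linear part a1*i + b1
    g += n * (2 * n - 1) * (n - 1) * a1 * a1 // 6
    g += 2 * nn * a1 * b1
    g += b1 * b1 * n
    g += 2 * a1 * h
    g += 2 * b1 * f
    f += a1 * nn + b1 * n
    h += nn * (a1 * (2 * n - 1) + 3 * b1) // 3
    return f, g, h


for _ in range(int(input())):
    n, m, a, b1, b2 = map(int, input().split())
    if a == 0:
        print(n * b1 * b2)
        continue
    if b1 >= b2:
        b1, b2 = b2, b1  # ensure b1 <= b2
    ans = 0
    # sum_k (a*k + b1) * (a*k + b2)
    ans += a * a * (n - 1) * n * (2 * n - 1) // 6
    ans += a * (b1 + b2) * (n - 1) * n // 2
    ans += b1 * b2 * n
    # - m * sum_k (a*k + b2) * floor((a*k + b1) / m)
    f, g, h = floor_sum(n, m, a, b1)
    ans -= m * (a * h + b2 * f)
    # - m * sum_k (a*k + b1) * floor((a*k + b2) / m)
    f, g, h = floor_sum(n, m, a, b2)
    ans -= m * (a * h + b1 * f)
    # res = sum_k floor((a*k+b1)/m) * floor((a*k+b2)/m), starting from sum_k floor((a*k+b2)/m)^2
    res = g
    x = (a * (n - 1) + b2) // m
    # Subtract sum_{c<X} c * d_c^1 + X * d_X^1
    f, g, h = floor_sum(x, a, m, a - b1 - 1)
    res -= h + x * min(n, (x * m + a - b1 - 1) // a)
    # Add sum_{c<X} c * d_c^2 + X * d_X^2
    f, g, h = floor_sum(x, a, m, a - b2 - 1)
    res += h + x * min(n, (x * m + a - b2 - 1) // a)
    ans += m * m * res
    print(ans)
```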