
G - Sum of Prod of Mod of Linear

Editorial by en_translator


This solution requires the generalized floor sum. Given integers \(p,q,N,M,A,B\), the generalized floor sum algorithm computes \(\displaystyle f_{p,q}(N,M,A,B)=\sum_{i=0}^{N-1}i^p\left\lfloor \frac{Ai+B}M \right\rfloor^q\) in \(O((p+q)^4\log M)\) time. For more details on the generalized floor sum, please refer to this article (in Japanese).
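The sketch below evaluates \(f_{p,q}(N,M,A,B)\) naively in \(O(N)\) time to make the definition concrete. It is meant only for checking the definition, or for testing a fast implementation, on small inputs. It is not the \(O((p+q)^4\log M)\) algorithm itself, and the function name gen_floor_sum_naive is hypothetical.

```python
def gen_floor_sum_naive(p: int, q: int, n: int, m: int, a: int, b: int) -> int:
    """Directly evaluate f_{p,q}(N, M, A, B) = sum_{i=0}^{N-1} i^p * floor((A*i + B) / M)^q.

    O(N) reference for the definition only (hypothetical helper, not the fast
    algorithm). Python's // is floor division, so negative A*i + B is handled too.
    """
    return sum(i ** p * ((a * i + b) // m) ** q for i in range(n))


# N=4, M=3, A=2, B=1: floor((2i+1)/3) = 0, 1, 1, 2 for i = 0..3.
print(gen_floor_sum_naive(0, 1, 4, 3, 2, 1))  # 0+1+1+2 = 4
print(gen_floor_sum_naive(1, 1, 4, 3, 2, 1))  # 0*0+1*1+2*1+3*2 = 9
print(gen_floor_sum_naive(0, 2, 4, 3, 2, 1))  # 0+1+1+4 = 6
```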


If \(A=0\), every term equals \(B_1B_2\) (because \(0\le B_1,B_2<M\)), so the answer is \(NB_1B_2\). Hence we assume \(A>0\).

By symmetry, we may also assume \(B_1\le B_2\).

We first rewrite the expression to be computed, using \(x\ \text{mod}\ M=x-M\left\lfloor \frac{x}{M}\right\rfloor\):

\[ \begin{aligned} &\phantom{=} \sum_{k=0}^{N-1}\left\lbrace (Ak+B_1)\ \text{mod}\ M \right\rbrace\left\lbrace (Ak+B_2)\ \text{mod}\ M \right\rbrace\\ &=\sum_{k=0}^{N-1}\left((Ak+B_1)-M\left\lfloor \frac{Ak+B_1}M \right\rfloor \right)\left((Ak+B_2)-M\left\lfloor \frac{Ak+B_2}M \right\rfloor \right)\\ &=\sum_{k=0}^{N-1}(Ak+B_1)(Ak+B_2)-M\sum_{k=0}^{N-1}\left((Ak+B_1)\left\lfloor\frac{Ak+B_2}M \right\rfloor+(Ak+B_2)\left\lfloor\frac{Ak+B_1}M \right\rfloor \right)+M^2\sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_1}M \right\rfloor\left\lfloor\frac{Ak+B_2}M \right\rfloor \\ \end{aligned} \]
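The expansion above is a purely algebraic identity, so it can be sanity-checked by brute force. The sketch below compares both sides on all small parameters with \(0\le B_1,B_2<M\). The helper names lhs and rhs are hypothetical.

```python
from itertools import product


def lhs(n, m, a, b1, b2):
    # Direct evaluation of sum ((Ak+B1) mod M) * ((Ak+B2) mod M)
    return sum(((a * k + b1) % m) * ((a * k + b2) % m) for k in range(n))


def rhs(n, m, a, b1, b2):
    # The three-sum expansion, with F_i = floor((Ak+B_i)/M)
    poly = sum((a * k + b1) * (a * k + b2) for k in range(n))
    cross = sum((a * k + b1) * ((a * k + b2) // m) + (a * k + b2) * ((a * k + b1) // m)
                for k in range(n))
    prod = sum(((a * k + b1) // m) * ((a * k + b2) // m) for k in range(n))
    return poly - m * cross + m * m * prod


# Exhaustive comparison on small parameters with 0 <= B1, B2 < M.
for n, m, a in product(range(6), range(1, 6), range(6)):
    for b1, b2 in product(range(m), repeat=2):
        assert lhs(n, m, a, b1, b2) == rhs(n, m, a, b1, b2)
```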

The part \(\displaystyle \sum_{k=0}^{N-1}(Ak+B_1)(Ak+B_2)-M\sum_{k=0}^{N-1}\left((Ak+B_1)\left\lfloor\frac{Ak+B_2}M \right\rfloor+(Ak+B_2)\left\lfloor\frac{Ak+B_1}M \right\rfloor \right)\) can be evaluated in \(O(\log M)\). The first sum has a closed form. The second sum equals \(A f_{1,1}(N,M,A,B_2)+B_1 f_{0,1}(N,M,A,B_2)+A f_{1,1}(N,M,A,B_1)+B_2 f_{0,1}(N,M,A,B_1)\), which the generalized floor sum computes. Therefore, it remains to evaluate \(\displaystyle \sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_1}M \right\rfloor\left\lfloor\frac{Ak+B_2}M \right\rfloor\).
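The sketch below illustrates this reduction under a simplifying assumption: f is a naive \(O(N)\) stand-in for the generalized floor sum, whereas a real solution would use the \(O(\log M)\) algorithm. The helper names are hypothetical. The first two parts are built from \(f_{0,1}\), \(f_{1,1}\) and the closed forms \(\sum k=N(N-1)/2\) and \(\sum k^2=(N-1)N(2N-1)/6\), and the result is compared with a direct evaluation.

```python
from itertools import product


def f(p, q, n, m, a, b):
    # Naive stand-in for the generalized floor sum f_{p,q}(N, M, A, B)
    return sum(i ** p * ((a * i + b) // m) ** q for i in range(n))


def first_two_terms(n, m, a, b1, b2):
    # sum (Ak+B1)(Ak+B2) via sum k = N(N-1)/2 and sum k^2 = (N-1)N(2N-1)/6
    poly = a * a * (n - 1) * n * (2 * n - 1) // 6 + a * (b1 + b2) * n * (n - 1) // 2 + b1 * b2 * n
    # sum (Ak+B1)F2 = A f_{1,1}(N,M,A,B2) + B1 f_{0,1}(N,M,A,B2); symmetric for (Ak+B2)F1
    cross = (a * f(1, 1, n, m, a, b2) + b1 * f(0, 1, n, m, a, b2)
             + a * f(1, 1, n, m, a, b1) + b2 * f(0, 1, n, m, a, b1))
    return poly - m * cross


def first_two_terms_direct(n, m, a, b1, b2):
    # Same quantity, summed term by term
    return sum((a * k + b1) * (a * k + b2)
               - m * ((a * k + b1) * ((a * k + b2) // m) + (a * k + b2) * ((a * k + b1) // m))
               for k in range(n))


for n, m, a in product(range(6), range(1, 6), range(6)):
    for b1, b2 in product(range(m), repeat=2):
        assert first_two_terms(n, m, a, b1, b2) == first_two_terms_direct(n, m, a, b1, b2)
```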

Here, let us consider the difference between \(\displaystyle \sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_1}M \right\rfloor\left\lfloor\frac{Ak+B_2}M \right\rfloor\) and \(\displaystyle \sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_2}M \right\rfloor^2\). The latter equals \(f_{0,2}(N,M,A,B_2)\) and can be computed with the generalized floor sum, so finding this difference gives us the answer. First, we have

\[\sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_1}M \right\rfloor\left\lfloor\frac{Ak+B_2}M \right\rfloor-\sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_2}M \right\rfloor^2=-\sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_2}M \right\rfloor\left( \left\lfloor\frac{Ak+B_2}M \right\rfloor -\left\lfloor\frac{Ak+B_1}M \right\rfloor \right).\]

Here, \(0\le B_1 \le B_2 < M\) implies \(Ak+B_1\le Ak+B_2<(Ak+B_1)+M\), so the value \(\displaystyle \left\lfloor\frac{Ak+B_2}M \right\rfloor -\left\lfloor\frac{Ak+B_1}M \right\rfloor\) is either \(0\) or \(1\).
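The following is a quick brute-force check of this claim on small parameters. The helper name floor_gap is hypothetical.

```python
from itertools import product


def floor_gap(k, m, a, b1, b2):
    # floor((Ak+B2)/M) - floor((Ak+B1)/M)
    return (a * k + b2) // m - (a * k + b1) // m


# With 0 <= B1 <= B2 < M, the gap is always 0 or 1.
for m in range(1, 8):
    for a, k, b1, b2 in product(range(10), range(10), range(m), range(m)):
        if b1 <= b2:
            assert floor_gap(k, m, a, b1, b2) in (0, 1)
```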

For a fixed \(c\), the integers \(k\) with \(0\le k<N\) satisfying \(\displaystyle c=\left\lfloor\frac{Ak+B_2}M \right\rfloor\) are exactly those with \(d_c^2\le k < d^2_{c+1}\). Here, for \(i\in\{1,2\}\) (a superscript index, not an exponent),

\[ d_c^i = \min\left(N,\left\lceil\frac{cM-B_i}A \right\rceil\right) \]
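The characterization of these ranges can be checked by brute force. This includes the clipping by \(N\) and the fact that \(d_0^2\) may be negative, so the range must also be clipped below at \(0\). The sketch assumes \(A\ge1\), \(N\ge1\) and \(0\le B<M\). The helpers ceil_div, d and check_ranges are hypothetical names.

```python
from itertools import product


def ceil_div(x, y):
    # ceil(x / y) for y > 0, via Python's floor division
    return -(-x // y)


def d(c, b_i, n, m, a):
    # d_c^i = min(N, ceil((c*M - B_i) / A)), where b_i plays the role of B_i
    return min(n, ceil_div(c * m - b_i, a))


def check_ranges(n, m, a, b):
    # For each c in [0, X], {k in [0, N) : floor((Ak+B)/M) = c} should equal
    # [d_c, d_{c+1}), clipped below at 0 (d_0 can be negative).
    x = (a * (n - 1) + b) // m
    for c in range(x + 1):
        expected = [k for k in range(n) if (a * k + b) // m == c]
        assert expected == list(range(max(0, d(c, b, n, m, a)), d(c + 1, b, n, m, a)))


for n, m, a in product(range(1, 8), range(1, 7), range(1, 8)):
    for b in range(m):
        check_ranges(n, m, a, b)
```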

Within this range, \(\displaystyle \left\lfloor\frac{Ak+B_1}M \right\rfloor=c-1\) holds exactly when \(Ak+B_1<cM\), that is, when \(d^2_{c}\le k < d^1_{c}\). Let \(\displaystyle X=\left\lfloor\frac{A(N-1)+B_2}M \right\rfloor\), the largest value of \(\displaystyle\left\lfloor\frac{Ak+B_2}M\right\rfloor\). Then the sought difference is

\[ \begin{aligned} &\phantom{=}- \sum_{k=0}^{N-1}\left\lfloor\frac{Ak+B_2}M \right\rfloor\left( \left\lfloor\frac{Ak+B_2}M \right\rfloor -\left\lfloor\frac{Ak+B_1}M \right\rfloor \right)\\ &=-\sum_{c=0}^Xc\sum_{k=d_{c}^2}^{d_{c}^1-1}1\\ &=-\sum_{c=0}^X c(d_{c}^1-d_{c}^2)\\ &=-\sum_{c=0}^{X-1} c(d_{c}^1-d_{c}^2)-X(d_{X}^1-d_{X}^2).\\ \end{aligned} \]
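This formula can be compared against a direct computation of the difference. The sketch below assumes \(N\ge1\), \(A\ge1\) and \(0\le B_1\le B_2<M\), as in the text. The function names are hypothetical.

```python
from itertools import product


def ceil_div(x, y):
    # ceil(x / y) for y > 0
    return -(-x // y)


def difference_naive(n, m, a, b1, b2):
    # sum F1*F2 - sum F2^2, with F_i = floor((Ak+B_i)/M)
    return sum(((a * k + b1) // m) * ((a * k + b2) // m) - ((a * k + b2) // m) ** 2
               for k in range(n))


def difference_via_d(n, m, a, b1, b2):
    # -sum_{c=0}^{X} c * (d_c^1 - d_c^2), with d_c^i = min(N, ceil((c*M - B_i)/A))
    x = (a * (n - 1) + b2) // m

    def d(c, b):
        return min(n, ceil_div(c * m - b, a))

    return -sum(c * (d(c, b1) - d(c, b2)) for c in range(x + 1))


for n, m, a in product(range(1, 8), range(1, 7), range(1, 8)):
    for b1, b2 in product(range(m), repeat=2):
        if b1 <= b2:
            assert difference_naive(n, m, a, b1, b2) == difference_via_d(n, m, a, b1, b2)
```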

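For \(c<X\), the \(\min\) never takes effect, so \(\displaystyle d_c^i =\left\lceil\frac{cM-B_i}A \right\rceil\). After rewriting this ceiling as a floor, \(\displaystyle \sum_{c=0}^{X-1}c\,d_c^i\) can also be found with the generalized floor sum. The remaining term for \(c=X\) is evaluated directly.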
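The conversion uses \(\lceil y/A\rceil=\lfloor (y+A-1)/A\rfloor\) for \(A>0\), which gives \(\sum_{c=0}^{X-1}c\,d_c^i=f_{1,1}(X,A,M,A-B_i-1)\). Note that \(A\) becomes the divisor and \(M\) becomes the slope. In the sample code, this is the h returned by floor_sum(x, a, m, a - b1 - 1) (and likewise for b2). The brute-force sketch below checks the identity. Its helper names are hypothetical.

```python
from itertools import product


def ceil_div(x, y):
    # ceil(x / y) for y > 0
    return -(-x // y)


def weighted_sum_ceil(x, m, a, b):
    # sum_{c=0}^{X-1} c * ceil((c*M - B) / A)
    return sum(c * ceil_div(c * m - b, a) for c in range(x))


def weighted_sum_floor(x, m, a, b):
    # f_{1,1}(X, A, M, A - B - 1) = sum_{c=0}^{X-1} c * floor((M*c + A - B - 1) / A)
    return sum(c * ((m * c + a - b - 1) // a) for c in range(x))


for x, m, a, b in product(range(8), range(1, 7), range(1, 7), range(7)):
    assert weighted_sum_ceil(x, m, a, b) == weighted_sum_floor(x, m, a, b)
```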

Implementing all of this solves the problem in \(O(\log M)\) time per query.

Sample code (Python3)

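```python
import sys


def floor_sum(n, m, a, b):
    """Return (f, g, h) where, with t_i = floor((a*i + b) / m) over 0 <= i < n,
    f = sum t_i, g = sum t_i^2 and h = sum i * t_i
    (that is, f_{0,1}, f_{0,2} and f_{1,1} of the generalized floor sum)."""
    if m == 0:
        return 0, 0, 0
    # Split a = a1*m + a2 and b = b1*m + b2 with 0 <= a2, b2 < m.
    a1, a2 = a // m, a % m
    b1, b2 = b // m, b % m
    # Reduce the sums of floor((a2*i + b2) / m) to a floor sum
    # with the roles of m and a2 swapped.
    y = (a2 * n + b2) // m
    ff, gg, hh = floor_sum(y, a2, m, m + a2 - b2 - 1)
    nn = n * (n - 1) // 2
    f = n * y - ff
    g = n * y * y - ff - 2 * hh
    h = nn * y + (ff - gg) // 2
    # Add back the contribution of the a1*i + b1 part.
    # g is updated first because it needs the old values of f and h.
    g += n * (2 * n - 1) * (n - 1) * a1 * a1 // 6
    g += 2 * nn * a1 * b1
    g += b1 * b1 * n
    g += 2 * a1 * h
    g += 2 * b1 * f
    f += a1 * nn + b1 * n
    h += nn * (a1 * (2 * n - 1) + 3 * b1) // 3
    return f, g, h


input = sys.stdin.readline
for _ in range(int(input())):
    n, m, a, b1, b2 = map(int, input().split())
    if a == 0:
        print(n * b1 * b2)
        continue
    if b1 >= b2:
        b1, b2 = b2, b1
    # sum (Ak+B1)(Ak+B2) in closed form
    ans = 0
    ans += a * a * (n - 1) * n * (2 * n - 1) // 6
    ans += a * (b1 + b2) * (n - 1) * n // 2
    ans += b1 * b2 * n
    # -M * sum (Ak+B2) * floor((Ak+B1)/M)
    f, g, h = floor_sum(n, m, a, b1)
    ans -= m * (a * h + b2 * f)
    # -M * sum (Ak+B1) * floor((Ak+B2)/M)
    f, g, h = floor_sum(n, m, a, b2)
    ans -= m * (a * h + b1 * f)
    # res = sum floor((Ak+B1)/M) * floor((Ak+B2)/M), starting from sum floor((Ak+B2)/M)^2
    res = g
    x = (a * (n - 1) + b2) // m
    # Subtract sum_{c=0}^{X} c * d_c^1 (the c = X term uses the min with N)
    f, g, h = floor_sum(x, a, m, a - b1 - 1)
    res -= h + x * min(n, (x * m + a - b1 - 1) // a)
    # Add back sum_{c=0}^{X} c * d_c^2
    f, g, h = floor_sum(x, a, m, a - b2 - 1)
    res += h + x * min(n, (x * m + a - b2 - 1) // a)
    ans += m * m * res
    print(ans)
```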
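A direct \(O(N)\) evaluation is handy for cross-checking the sample code on small inputs. The sketch below computes the sum straight from its definition. The name solve_naive is hypothetical and is not part of the original solution.

```python
def solve_naive(n: int, m: int, a: int, b1: int, b2: int) -> int:
    """O(N) reference for sum_{k=0}^{N-1} ((A*k + B1) mod M) * ((A*k + B2) mod M)."""
    return sum(((a * k + b1) % m) * ((a * k + b2) % m) for k in range(n))


print(solve_naive(4, 5, 3, 1, 2))  # terms 1*2, 4*0, 2*3, 0*1 -> 8
```

For example, one could move the per-query body of the sample code into a function and compare it with solve_naive on random small \((N,M,A,B_1,B_2)\) with \(0\le B_1,B_2<M\).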
