G - Sum of Pow of Mod of Linear 解説
by
sounansya
公式解説の補足
公式解説では \(N < M\) に限定して考えていましたが、一般に \(\lbrace (k,(Ak+B) \bmod M)\ |\ 0\le k < N \rbrace\) を \(O(\sqrt N)\) 個の等差数列に分解することを考えます。
まず \(D=O(\sqrt N)\) となる定数 \(D\) を取り、\(\min(Ai \bmod M, (-Ai) \bmod M)\) が最小となる \(1\le i < D\) を探します。この \(i\) を \(i_0\) とします。
この \(i_0\) に対し、鳩の巣原理から \(\displaystyle \min(Ai_0 \bmod M, (-Ai_0) \bmod M) \le \frac{M}{D}\) が成り立つことに留意してください。
もし \(Ai_0 \bmod M > (-Ai_0) \bmod M\) が成り立つ場合は \((A,B) \leftarrow (-A,M-1-B)\) とした場合を考えることで \(\displaystyle Ai_0 \bmod M \le \frac{M}{D}\) となる場合に帰着させることができます。以降はこの場合を考えます。
\(\lbrace (k,(Ak+B) \bmod M)\ |\ 0\le k < N \rbrace\) を一旦 \(k \bmod i_0\) の値で分類します。各等差数列は \(d=0,1,\ldots,i_0-1\) に対し \(\displaystyle \left\lbrace (ki_0+d,(A(ki_0+d)+B) \bmod M)\ |\ 0\le k < \left\lceil \frac{N-d}{i_0}\right\rceil \right\rbrace\) と表されます。これを公式解説のステップ 1. のように前から愚直に等差数列となるギリギリまで取ると、\(\displaystyle Ai_0 \bmod M \le \frac{M}{D}\) より \(\displaystyle O\left(\frac{1}{D}\times \frac{N}{i_0}\right)\) 個の等差数列に分かれます。したがって、全体で \(\displaystyle O\left(\frac{N}{D}\right)=O(\sqrt{N})\) 個の等差数列に分けることができます。
#include <atcoder/modint>
#include <bits/stdc++.h>
using namespace std;
vector<tuple<long, long, long, long, long>> sqrt_floor(long n, long m, long a, long b) {
// [(x, y, Δx, Δy, n), ... ]
if (n < 10) {
vector<tuple<long, long, long, long, long>> ans;
for (int i = 0; i < n; i++) ans.emplace_back(i, (a * i + b) % m, 0, 0, 1);
return ans;
}
const long D = sqrt(n);
long best_idx = -1, best_val = m;
for (int i = 1; i <= D; i++) {
long v = a * i % m;
v = min(v, m - v);
if (best_val > v) {
best_val = v;
best_idx = i;
}
}
const long delta_i = best_idx;
const long delta_j = delta_i * a % m;
if (delta_j > m - delta_j) {
vector<tuple<long, long, long, long, long>> ans = sqrt_floor(n, m, (m - a) % m, m - 1 - b);
for (auto &[x, y, xx, yy, n] : ans) {
y = m - 1 - y;
yy = -yy;
}
return ans;
}
const long a2 = a * delta_i % m;
vector<tuple<long, long, long, long, long>> ans;
for (int di = 0; di < best_idx; di++) {
const long b2 = (b + a * di) % m;
const long n2 = (n - di + best_idx - 1) / best_idx;
const long lim = (a2 * (n2 - 1) + b2) / m;
long le = 0;
for (int k = 0; k <= lim; k++) {
const long ri = k == lim ? n2 : (m * (k + 1) - b2 + a2 - 1) / a2;
ans.emplace_back(delta_i * le + di, (a2 * le + b2) % m, delta_i, delta_j, ri - le);
le = ri;
}
}
return ans;
}
using modint = atcoder::modint;
modint mp(modint x, long n) { // 1 + x + x^2 + ... + x^{n-1}
if (n == 0) return 0;
if (n % 2) return 1 + x * mp(x, n - 1);
return (1 + x) * mp(x * x, n / 2);
}
int main() {
cin.tie(nullptr);
ios::sync_with_stdio(false);
int t;
cin >> t;
while (t--) {
long n, m, a, b, x, r;
cin >> n >> m >> a >> b >> x >> r;
auto res = sqrt_floor(n, m, a, b);
modint::set_mod(r);
modint ans = 0;
for (auto [_, y, _, yy, len] : res) {
if (yy < 0) {
y += (len - 1) * yy;
yy *= -1;
}
ans += modint(x).pow(y) * mp(modint(x).pow(yy), len);
}
cout << ans.val() << '\n';
}
}
投稿日時:
最終更新:
