公式

G - Sum of Pow of Mod of Linear 解説 by sounansya

公式解説の補足

公式解説では \(N < M\) に限定して考えていましたが、一般に \(\lbrace (k,(Ak+B) \bmod M)\ |\ 0\le k < N \rbrace\)\(O(\sqrt N)\) 個の等差数列に分解することを考えます。

まず \(D=O(\sqrt N)\) となる定数 \(D\) を取り、\(\min(Ai \bmod M, (-Ai) \bmod M)\) が最小となる \(1\le i < D\) を探します。この \(i\)\(i_0\) とします。

この \(i_0\) に対し、鳩の巣原理から \(\displaystyle \min(Ai_0 \bmod M, (-Ai_0) \bmod M) \le \frac{M}{D}\) が成り立つことに留意してください。

もし \(Ai_0 \bmod M > (-Ai_0) \bmod M\) が成り立つ場合は \((A,B) \leftarrow (-A,M-1-B)\) とした場合を考えることで \(\displaystyle Ai_0 \bmod M \le \frac{M}{D}\) となる場合に帰着させることができます。以降はこの場合を考えます。

\(\lbrace (k,(Ak+B) \bmod M)\ |\ 0\le k < N \rbrace\) を一旦 \(k \bmod i_0\) の値で分類します。各等差数列は \(d=0,1,\ldots,i_0-1\) に対し \(\displaystyle \left\lbrace (ki_0+d,(A(ki_0+d)+B) \bmod M)\ |\ 0\le k < \left\lceil \frac{N-d}{i_0}\right\rceil \right\rbrace\) と表されます。これを公式解説のステップ 1. のように前から愚直に等差数列となるギリギリまで取ると、\(\displaystyle Ai_0 \bmod M \le \frac{M}{D}\) より \(\displaystyle O\left(\frac{1}{D}\times \frac{N}{i_0}\right)\) 個の等差数列に分かれます。したがって、全体で \(\displaystyle O\left(\frac{N}{D}\right)=O(\sqrt{N})\) 個の等差数列に分けることができます。

実装例(C++)

#include <atcoder/modint>
#include <bits/stdc++.h>
using namespace std;
vector<tuple<long, long, long, long, long>> sqrt_floor(long n, long m, long a, long b) {
	// [(x, y, Δx, Δy, n), ... ]
	if (n < 10) {
		vector<tuple<long, long, long, long, long>> ans;
		for (int i = 0; i < n; i++) ans.emplace_back(i, (a * i + b) % m, 0, 0, 1);
		return ans;
	}
	const long D = sqrt(n);
	long best_idx = -1, best_val = m;
	for (int i = 1; i <= D; i++) {
		long v = a * i % m;
		v = min(v, m - v);
		if (best_val > v) {
			best_val = v;
			best_idx = i;
		}
	}
	const long delta_i = best_idx;
	const long delta_j = delta_i * a % m;
	if (delta_j > m - delta_j) {
		vector<tuple<long, long, long, long, long>> ans = sqrt_floor(n, m, (m - a) % m, m - 1 - b);
		for (auto &[x, y, xx, yy, n] : ans) {
			y = m - 1 - y;
			yy = -yy;
		}
		return ans;
	}
	const long a2 = a * delta_i % m;
	vector<tuple<long, long, long, long, long>> ans;
	for (int di = 0; di < best_idx; di++) {
		const long b2 = (b + a * di) % m;
		const long n2 = (n - di + best_idx - 1) / best_idx;
		const long lim = (a2 * (n2 - 1) + b2) / m;
		long le = 0;
		for (int k = 0; k <= lim; k++) {
			const long ri = k == lim ? n2 : (m * (k + 1) - b2 + a2 - 1) / a2;
			ans.emplace_back(delta_i * le + di, (a2 * le + b2) % m, delta_i, delta_j, ri - le);
			le = ri;
		}
	}
	return ans;
}
using modint = atcoder::modint;
modint mp(modint x, long n) { // 1 + x + x^2 + ... + x^{n-1}
	if (n == 0) return 0;
	if (n % 2) return 1 + x * mp(x, n - 1);
	return (1 + x) * mp(x * x, n / 2);
}
int main() {
	cin.tie(nullptr);
	ios::sync_with_stdio(false);
	int t;
	cin >> t;
	while (t--) {
		long n, m, a, b, x, r;
		cin >> n >> m >> a >> b >> x >> r;
		auto res = sqrt_floor(n, m, a, b);
		modint::set_mod(r);
		modint ans = 0;
		for (auto [_, y, _, yy, len] : res) {
			if (yy < 0) {
				y += (len - 1) * yy;
				yy *= -1;
			}
			ans += modint(x).pow(y) * mp(modint(x).pow(yy), len);
		}
		cout << ans.val() << '\n';
	}
}

投稿日時:
最終更新: