G - Discrete Logarithm Problems Editorial
by
MMNMM
各要素の位数を求めるパートを公式解説より少し高速に行う方法について解説します。
そのパート以外は完全に同じなので、前後の説明を省略して文字も流用します。
- \(M\) を \(P-1\) とする
- \(k=1,2,\ldots,n\) の順に以下を繰り返す
- \(M\) を \(p _ k\) で割れるだけ割り、\(t=x ^ M\) とする。
- \(t\neq1 \bmod P\) である限り、\(M\) を \(Mp _ k\) に、\(t\) を \(t ^ {p _k}\) に置き換える操作を繰り返す
- 全ての操作が終了した時点の \(M\) が \(x\) の位数である
各 \(k\) について、ひとつめの操作が終わった後の \(M\) と、二回目の操作が繰り返される回数 \(c\) を考えると、累乗の計算にかかる時間は全体で \(O(\log M+c\log p _ k)=O(\log (Mp _ k ^ c))=O(\log P)\) 時間となることがわかります。 それ以外の部分も各 \(k\) についてたかだか \(O(\log P)\) 時間となるので、全体で \(O(n\log P)\) 時間に改善されています。
実装例は以下のようになります。
#include <bits/extc++.h>
// ab mod m (m < 10^13) を求める
constexpr unsigned long prod_mod(unsigned long a, unsigned long b, unsigned long m) {
constexpr unsigned long mask{1048575};
unsigned long a_high{a >> 40}, a_mid{(a >> 20) & mask}, a_low{a & mask};
unsigned long r{};
(r += b * a_high) %= m;
(r <<= 20) %= m;
(r += b * a_mid) %= m;
(r <<= 20) %= m;
(r += b * a_low) %= m;
return r;
}
// a^b mod m を求める
constexpr unsigned long pow_mod(unsigned long a, unsigned long b, unsigned long m) {
unsigned long r{1};
while (b) {
if (b & 1)
r = prod_mod(r, a, m);
a = prod_mod(a, a, m);
b /= 2;
}
return r;
}
int main() {
using namespace std;
unsigned N;
unsigned long P;
cin >> N >> P;
// P - 1 の素因数と約数を求める
vector<unsigned long> prime_divisors, divisors;
{
unsigned long p{P - 1};
for (unsigned long i{2}; i * i <= p; ++i)
if (1 < i && p % i == 0) {
prime_divisors.emplace_back(i);
while (p % i == 0)
p /= i;
}
if (p > 1)
prime_divisors.emplace_back(p);
}
for (unsigned long i{1}, p{P - 1}; i * i <= p; ++i)
if (p % i == 0) {
divisors.emplace_back(i);
if (i * i < p)divisors.emplace_back(p / i);
}
ranges::sort(divisors);
// A[i] の位数として出現する値の頻度をまとめる
unordered_map<unsigned long, unsigned> ord_count;
for (const auto a : views::istream<unsigned long>(cin)) {
// a の位数を求める: O(nlog P) 時間
unsigned long ord{P - 1};
for (const auto p : prime_divisors) {
while (ord % p == 0)
ord /= p;
auto now{pow_mod(a, ord, P)};
while (now != 1) {
ord *= p;
now = pow_mod(now, p, P);
}
}
++ord_count[ord];
}
// 位数を整除関係で累積和
auto ord_count_accumulate{ord_count};
for (const auto p : prime_divisors)
for (const auto d : divisors | views::reverse)
if (d % p == 0 && ord_count_accumulate.contains(d))
ord_count_accumulate[d / p] += ord_count_accumulate[d];
// 組の個数をカウントして出力
unsigned long ans{};
for (const auto&[x, y] : ord_count)
ans += static_cast<unsigned long>(y) * ord_count_accumulate[x];
cout << ans << endl;
return 0;
}
posted:
last update: