G - Discrete Logarithm Problems 解説 by MMNMM


各要素の位数を求めるパートを公式解説より少し高速に行う方法について解説します。

そのパート以外は完全に同じなので、前後の説明を省略して文字も流用します。

  • \(M\) を \(P-1\) とする
  • \(k=1,2,\ldots,n\) の順に以下を繰り返す
    • \(M\) を \(p _ k\) で割れるだけ割り、\(t=x ^ M\) とする。
    • \(t\neq1 \bmod P\) である限り、\(M\) を \(Mp _ k\) に、\(t\) を \(t ^ {p _k}\) に置き換える操作を繰り返す
  • 全ての操作が終了した時点の \(M\) が \(x\) の位数である

各 \(k\) について、ひとつめの操作が終わった後の \(M\) と、二回目の操作が繰り返される回数 \(c\) を考えると、累乗の計算にかかる時間は全体で \(O(\log M+c\log p _ k)=O(\log (Mp _ k ^ c))=O(\log P)\) 時間となることがわかります。 それ以外の部分も各 \(k\) についてたかだか \(O(\log P)\) 時間となるので、全体で \(O(n\log P)\) 時間に改善されています。

実装例は以下のようになります。

#include <bits/extc++.h>

// ab mod m (m < 10^13) を求める
constexpr unsigned long prod_mod(unsigned long a, unsigned long b, unsigned long m) {
    constexpr unsigned long mask{1048575};
    unsigned long a_high{a >> 40}, a_mid{(a >> 20) & mask}, a_low{a & mask};
    unsigned long r{};
    (r += b * a_high) %= m;
    (r <<= 20) %= m;
    (r += b * a_mid) %= m;
    (r <<= 20) %= m;
    (r += b * a_low) %= m;
    return r;
}

// a^b mod m を求める
constexpr unsigned long pow_mod(unsigned long a, unsigned long b, unsigned long m) {
    unsigned long r{1};
    while (b) {
        if (b & 1)
            r = prod_mod(r, a, m);
        a = prod_mod(a, a, m);
        b /= 2;
    }
    return r;
}

int main() {
    using namespace std;
    unsigned N;
    unsigned long P;
    cin >> N >> P;

    // P - 1 の素因数と約数を求める
    vector<unsigned long> prime_divisors, divisors;
    {
        unsigned long p{P - 1};
        for (unsigned long i{2}; i * i <= p; ++i)
            if (1 < i && p % i == 0) {
                prime_divisors.emplace_back(i);
                while (p % i == 0)
                    p /= i;
            }
        if (p > 1)
            prime_divisors.emplace_back(p);
    }
    for (unsigned long i{1}, p{P - 1}; i * i <= p; ++i)
        if (p % i == 0) {
            divisors.emplace_back(i);
            if (i * i < p)divisors.emplace_back(p / i);
        }
    ranges::sort(divisors);

    // A[i] の位数として出現する値の頻度をまとめる
    unordered_map<unsigned long, unsigned> ord_count;
    for (const auto a : views::istream<unsigned long>(cin)) {
        // a の位数を求める: O(nlog P) 時間
        unsigned long ord{P - 1};
        for (const auto p : prime_divisors) {
            while (ord % p == 0)
                ord /= p;
            auto now{pow_mod(a, ord, P)};
            while (now != 1) {
                ord *= p;
                now = pow_mod(now, p, P);
            }
        }
        ++ord_count[ord];
    }

    // 位数を整除関係で累積和
    auto ord_count_accumulate{ord_count};
    for (const auto p : prime_divisors)
        for (const auto d : divisors | views::reverse)
            if (d % p == 0 && ord_count_accumulate.contains(d))
                ord_count_accumulate[d / p] += ord_count_accumulate[d];

    // 組の個数をカウントして出力
    unsigned long ans{};
    for (const auto&[x, y] : ord_count)
        ans += static_cast<unsigned long>(y) * ord_count_accumulate[x];
    cout << ans << endl;

    return 0;
}

投稿日時:
最終更新: