G - Fibonacci Product 解説 by MMNMM

いくつかの高速化について

いくつかの考察・式変形を行うことで(オーダーレベルでは変化しませんが)時間計算量を少なくすることができます。 この解説では、そのような考察を \(2\) つ紹介します。

公式解説の内容を前提とします。

以下、\(p=998244353\) とします。


1. 一周期に対する答え

数列 \(a _ n\) は周期 \(2p+2\) を持ちます。 \(N=2p+2\) のときの答えは非常にいい性質を持ちそうなので、この場合について考察します。 これは実際にいい性質を持っており、\(N=2p+2\) に対する答えを \(O(\log p)\) 時間で求めることができます。 この節ではこれを解説します。


\(\mathbb F _ p\) 係数の \(2\) 次正方行列 \(A=\begin{pmatrix}1&1\\1&0\end{pmatrix}\) をとります。
\(2\) 次正方行列全体からなる線形空間 \(M _ 2(\mathbb F _ p)\) の部分空間として \(V=\left\lbrace\begin{pmatrix}a&b\\b&a+b\end{pmatrix}\;\middle|\;a,b\in\mathbb F _ p\right\rbrace\) を考えます。 線形写像 \(f\colon V\to V\) を \(f(B)=AB\) と定めると \(f\) は全単射です(\(A\) が正則なので)。

\(p ^ 2\) 個存在する \(V\) の要素は \(O=\begin{pmatrix}0&0\\0&0\end{pmatrix}\) を除いて正則です。

証明

\(a ^ 2+ab-b ^ 2=0\) なら \(a=b=0\) を示したいです。背理法を用いて示します。

\(a=0\) もしくは \(b=0\) のとき、明らかです。

そうでないとき、両辺を \(b ^ 2\) で割って \(\left(\dfrac ab\right) ^ 2+\dfrac ab-1=0\) より \(\left(2\times\dfrac ab+1\right)^2=5\) ととなりますが、\(5\) が \({}\bmod p\) で平方非剰余であることからこのような \(a,b\) は存在しないことがわかります。\(\square\)

\(A ^ n=E\) となる最小の正の整数は \(2p+2\) です(ここで、\(E\) は単位行列です。これが正しいことは、具体的に計算してみることなどで確かめられます)。 よって、ある正則行列 \(B\in V\) からはじめて無限列 \(B,f(B),f(f(B)),\ldots\) を作ったとき、この列の最小の周期は \(2p+2\) です。

証明

\(2p+2\) を周期にもつことは明らかです。

ある \(i\lt j\) について \(f ^ i(B)=f ^ j(B)\) となったとします。

\(A,B\) が正則なので \(A ^ iB\) も正則であり、両辺に右から \((A ^ iB) ^ {-1}\) をかけることで \(E=A ^ {j-i}\) を得ることができます。 \(j-i\) は正なので、これは \(2p+2\) 以上です。\(\square\)

ここで、\(V\) に含まれる正則行列 \(B\) の行列式 \(\det B\) の値として \(1,2,\ldots,p-1\) のすべてがありえることを示します。

証明

まず、任意の \(a,b,k\in\mathbb F _ p\) について \(\det\begin{pmatrix}ka&kb\\kb&k(a+b)\end{pmatrix}=k ^ 2\det\begin{pmatrix}a&b\\b&a+b\end{pmatrix}\) が成り立ちます。

よって、ある \(B\) について \(\det B\) が \({}\bmod p\) で(\(0\) でない)平方剰余になることと、ある \(B\) について \(\det B\) が \({}\bmod p\) で平方非剰余になることを示せばよいです。

ここで、\(a=1,b=0\) とすると \(\det\begin{pmatrix}1&0\\0&1\end{pmatrix}=1\) で、\(a=2,b=1\) とすると \(\det\begin{pmatrix}2&1\\1&3\end{pmatrix}=5\) となり、示されました。\(\square\)

\(\det A=-1\) なので、\(\det(f ^ n(B))=(-1) ^ n\det B\) が成り立ちます。 つまり、列 \(B,f(B),f(f(B)),\ldots\) の要素の行列式は \(\det B\) か \(-\det B\) のどちらかに等しいです。

よって、\(i=1,2,\ldots,\dfrac{p-1}2\) それぞれについて \(\det B _ i=i\) なる \(B _ i\in V\) をとると、\(V\) はサイズ \(1\) の集合 \(\lbrace O\rbrace\) と、\(\dfrac{p-1}2\) 個のサイズ \(2p+2\) の集合 \(\lbrace B _ i,f(B _ i),f(f(B _ i)),\ldots\rbrace\) の非交和として書けることがわかります。

特に、\(\det B=\det C\) ならば、\(\lbrace B,f(B),f(f(B)),\ldots\rbrace=\lbrace C,f( C),f(f( C)),\ldots\rbrace\) が成り立つことがわかります。


よって、\(x,y\) が与えられたとき、\(N=2p+2\) における答えは \(x ^ 2+xy-y ^ 2\) の値によってのみ定まります。 特に、これが平方剰余であるかによって大きく \(2\) つに分けることができます。

  1. \(x ^ 2+xy-y ^ 2=k ^ 2\) なる \(k\in\mathbb F _ p\) が存在する場合
    • \(x=k,y=0\) に対する答えと等しいことがわかるので、これは \(0\) です。
  2. \(x ^ 2+xy-y ^ 2=k ^ 2\) なる \(k\in\mathbb F _ p\) が存在しない場合
    • このとき、\(x ^ 2+xy-y ^ 2=5k ^ 2\) なる \(k\) が存在します。ここから \(x=2k,y=k\) に対する答えと等しいことがわかります。\(x=2,y=1\) に対する答えは \(-16\) なので(実際に計算をすることでわかります)、答えは \(-16k ^ {2p+2}=-16k ^ 4\) です。

ある値 \(D\) が平方剰余であるかは \(D ^ {(p-1)/2}\) を計算することなどで求めることができます。

以上より、\(N=2p+2\) に対する答えを \(O(\log p)\) 時間で求めることができました。


2. 前計算とその高速化

公式解説では、\(\displaystyle F _ M(X)=\prod _ {i = 0} ^ {M-1}(c _ 1D ^ iX+c _ 2)\) に \(X=1,D ^ M,D ^ {2M},\ldots\) を代入することで答えを求めています。 この \(F _ M\) は \(x,y\) が入力されるごとに異なるものになりますが、少し式を変形することで入力によらない部分を取り出すことができます。 また、\(F _ M\) も性質のよい多項式であるので、少し考察をすることで \(F _ M\) を求めるのにかかる時間計算量を公式解説の \(O(M\log M)\) 時間から \(O(M+\log p)\) 時間とすることができます。 この節ではこれらを解説します。


数列 \(a _ n\) は定数 \(c _ 1,c _ 2,A,B\in\mathbb F _ p(\sqrt5)\) を用いて、\(a _ n=c _ 1A ^ n+c _ 2B ^ n\) と表すことができます。 ここで、\(c _ 1,c _ 2\) は \(x,y\) によって変わりますが、\(A,B\) は \(x,y\) によらない定数 \(\dfrac{-1\pm\sqrt5}2\) とすることができます。

ここで、\(D=\dfrac AB\) とすると \[\dfrac{a _ n}{c _ 2B ^ n}=\dfrac{c _ 1}{c _ 2}D ^ n+1\] であることから、求める総積は \[\prod _ {i=0} ^ {N-1}a _ n=c _ 2^NB ^ {N(N-1)/2}\prod _ {i=0} ^ {N-1}\left(\dfrac{c _ 1}{c _ 2}D ^ n+1\right)\] とできます。

ここで \(\displaystyle F _ n(X)=\prod _ {i=0} ^ {n-1}(D ^ iX+1)\) とすると、これは \(\displaystyle c _ 2 ^ NB ^ {N(N-1)/2}F _ N\left(\dfrac{c _ 1}{c _ 2}\right)\) です。

公式解説と同様に適当な \(M=\Theta(\sqrt p)\) をとります。 \(F _ M(X)\) の \(i\) 次 \((0\leq i\leq M)\) の係数が求められれば、\(O(M)\) 個の点 \(X=\dfrac{c _ 1}{c _ 2},\dfrac{c _ 1}{c _ 2}D ^ M,\dfrac{c _ 1}{c _ 2}D ^ {2M},\ldots\) における \(F _ M(X)\) の値を評価することは \(O(M\log M)\) 時間で可能です。 これらを用いて追加の \(O(M)\) 時間で答えを求めることも公式解説と同様にできます。

ここで、\(M\) を事前に決め打っておく(\(M\) を決め打っても最悪時の時間計算量のオーダーは変動しません)ことで \(F _ M(X)\) の係数を入力によらない値とすることができ、入力を読む前に計算を行うことができます。


次に、\(F _ M(X)\) の係数を高速に求めることを考えます。

ここで、次の恒等式を考えます。

\[(D ^ MX+1)F _ M(X)=(X+1)F _ M(DX)\]

これは、\(F _ M\) の取りかたからすぐに示すことができます。

\(\displaystyle F _ M(X)=\sum _ {i=0} ^ Mc _ iX ^ i\) をこれに代入すると、\[D ^ Mc _ {i-1}+c _ i=D ^ {i-1}c _ {i-1}+D ^ ic _ i\ (1\leq i\leq M)\] が得られます。 この式のみからは全体の定数倍の情報は得られませんが、\(F _ M(X)\) の定義から明らかに \(c _ 1=1\) なので、これと合わせるとすべての \(c _ i\) が定まります。

これを適切に実装することで \(F _ M(X)\) の \(X ^ i\) の係数を全体で \(O(M+\log p)\) 時間で求めることができます。


以上より、テストケースによらない前計算を \(O(M+\log p)\) 時間、テストケースごとに等比級数をなす点に対する多点評価をたかだか \(1\) 回と \(O(M+\log p)\) 時間の処理で答えを求めることができました。

\(F _ M(X)\) の係数の計算や、多点評価に用いる \(D ^ {\pm N(N-1)/2}\) の計算は入力をまったく必要としないため、(対応する機能がある言語では)コンパイル時に計算することができます(あるいは、埋め込みができると言い換えることもできます)。 適切に実装することで、実行時の計算を減らすことができます。

実装例は以下のようになります(コンパイル時計算が多く、コンパイルに長い時間がかかる場合があるので注意してください)。

#include <array>
#include <cassert>
#include <iostream>
#include <ranges>
#include <utility>
#include <vector>

#include <atcoder/convolution>
#include <atcoder/math>
#include <atcoder/modint>

// a + b√5 の演算
template <typename T>
constexpr std::pair<T, T> operator+(const std::pair<T, T>& lhs, const std::pair<T, T>& rhs) {
    return {lhs.first + rhs.first, lhs.second + rhs.second};
}

template <typename T>
constexpr std::pair<T, T> operator-(const std::pair<T, T>& lhs, const std::pair<T, T>& rhs) {
    return {lhs.first - rhs.first, lhs.second - rhs.second};
}

template <typename T>
constexpr std::pair<T, T> operator*(const std::pair<T, T>& lhs, const std::pair<T, T>& rhs) {
    return {lhs.first * rhs.first + 5 * lhs.second * rhs.second, lhs.first * rhs.second + lhs.second * rhs.first};
}

template <typename T>
constexpr std::pair<T, T> operator/(const std::pair<T, T>& lhs, const std::pair<T, T>& rhs) {
    const std::pair<T, T> inv{rhs.first / (rhs.first * rhs.first - 5 * rhs.second * rhs.second), -rhs.second / (rhs.first * rhs.first - 5 * rhs.second * rhs.second)};
    return lhs * inv;
}

// modint の入出力
namespace atcoder {
    template <int m>
    std::ostream& operator<<(std::ostream& os, const static_modint<m>& mint) {
        os << mint.val();
        return os;
    }

    template <int m>
    std::istream& operator>>(std::istream& is, static_modint<m>& n) {
        int tmp;
        is >> tmp;
        n = tmp;
        return is;
    }
} // namespace atcoder

// コンパイル時にできる前計算
namespace precalc {
    constexpr unsigned mod{998244353};
    constexpr unsigned period{mod * 2 + 2}; // 列の周期
    constexpr unsigned sub_period{40960}; // ブロックサイズ
    constexpr struct compile_time_precalc_t {
        std::array<unsigned long, sub_period + 1> pow_real, pow_imag;
        std::array<unsigned long, period / sub_period + sub_period + 1> triangular_pow_real, triangular_pow_imag;
        std::array<unsigned long, sub_period + 1> f_coeff_real, f_coeff_imag;

        constexpr compile_time_precalc_t() : pow_real{}, pow_imag{}, triangular_pow_real{}, triangular_pow_imag{}, f_coeff_real{}, f_coeff_imag{} {
            // 以下、D = (-(3 + √5)/2) とする

            // D^i の計算
            pow_real[0] = 1;
            pow_imag[0] = 0;
            for (unsigned i{1}; i <= sub_period; ++i) {
                pow_real[i] = (pow_real[i - 1] * 499122175 + pow_imag[i - 1] * 499122174) % mod;
                pow_imag[i] = (pow_real[i - 1] * 499122176 + pow_imag[i - 1] * 499122175) % mod;
            }

            // (D^sub_period)^(t_i) の計算
            // ここで、t_i = i × (i + 1) / 2
            unsigned long imag_temporary{};
            triangular_pow_real[0] = 1;
            triangular_pow_imag[0] = 0;
            for (unsigned i{1}, i_upto{triangular_pow_real.size()}; i < i_upto; ++i) {
                triangular_pow_real[i] = (triangular_pow_real[i - 1] * pow_real[sub_period] + 5 * triangular_pow_imag[i - 1] * pow_imag[sub_period]) % mod;
                triangular_pow_imag[i] = (triangular_pow_real[i - 1] * pow_imag[sub_period] + triangular_pow_imag[i - 1] * pow_real[sub_period]) % mod;
            }
            for (unsigned i{1}, i_upto{triangular_pow_real.size()}; i < i_upto; ++i) {
                imag_temporary = triangular_pow_imag[i];
                triangular_pow_imag[i] = (triangular_pow_real[i - 1] * triangular_pow_imag[i] + triangular_pow_imag[i - 1] * triangular_pow_real[i]) % mod;
                triangular_pow_real[i] = (triangular_pow_real[i - 1] * triangular_pow_real[i] + 5 * imag_temporary % mod * triangular_pow_imag[i - 1]) % mod;
            }

            // f(x) = Π (1 + D^i x) の係数を求める
            // f(Dx) (1 + x) = f(x) (1 + D^sub_period x) を利用して O(sub_period + log mod) 時間で計算
            {
                unsigned long state_real{1}, state_imag{0};
                for (unsigned i{sub_period}; i; --i) {
                    f_coeff_real[i] = state_real;
                    f_coeff_imag[i] = state_imag;
                    imag_temporary = state_imag;
                    state_imag = (state_real * pow_imag[i] + state_imag * (pow_real[i] - 1 + mod)) % mod;
                    state_real = (state_real * (pow_real[i] - 1 + mod) + 5 * imag_temporary % mod * pow_imag[i]) % mod;
                }
                f_coeff_real[0] = state_real;
                f_coeff_imag[0] = state_imag;
            }
            {
                unsigned long state_real{f_coeff_real[0]}, state_imag{mod - f_coeff_imag[0]};
                {
                    const auto norm{(state_real * state_real + (mod - 5) * state_imag % mod * state_imag) % mod};
                    const auto inv_norm{
                        [](unsigned long a, unsigned long n, unsigned long m) {
                            unsigned long r{1};
                            while (n) {
                                if (n & 1)
                                    r = r * a % m;
                                a = a * a % m;
                                n >>= 1;
                            }
                            return r;
                        }(norm, mod - 2, mod)
                    };
                    state_real = state_real * inv_norm % mod;
                    state_imag = state_imag * inv_norm % mod;
                }
                imag_temporary = f_coeff_imag[0];
                f_coeff_imag[0] = (f_coeff_real[0] * state_imag + imag_temporary % mod * state_real) % mod;
                f_coeff_real[0] = (f_coeff_real[0] * state_real + 5 * imag_temporary % mod * state_imag) % mod;
                for (unsigned i{1}; i <= sub_period; ++i) {
                    imag_temporary = state_imag;
                    const auto coef_real_i{pow_real.back() - pow_real[i - 1] + mod}, coef_imag_i{pow_imag.back() - pow_imag[i - 1] + mod};
                    state_imag = (state_real * coef_imag_i + state_imag * coef_real_i) % mod;
                    state_real = (state_real * coef_real_i + 5 * imag_temporary % mod * coef_imag_i) % mod;
                    imag_temporary = f_coeff_imag[i];
                    f_coeff_imag[i] = (f_coeff_real[i] * state_imag + imag_temporary % mod * state_real) % mod;
                    f_coeff_real[i] = (f_coeff_real[i] * state_real + 5 * imag_temporary % mod * state_imag) % mod;
                }
            }
        }

        // (D^sub_period)^(t_i) を返す
        std::pair<atcoder::static_modint<mod>, atcoder::static_modint<mod>> triangular_pow(unsigned i) const {
            return {atcoder::static_modint<mod>::raw(triangular_pow_real[i]), atcoder::static_modint<mod>::raw(triangular_pow_imag[i])};
        }

        // (D^sub_period)^(-t_i) を返す
        std::pair<atcoder::static_modint<mod>, atcoder::static_modint<mod>> triangular_pow_inv(unsigned i) const {
            return {atcoder::static_modint<mod>::raw(triangular_pow_real[i]), atcoder::static_modint<mod>{mod - triangular_pow_imag[i]}};
        }

        // f(x) = Π (1 + D^i x) の i 次の係数を返す
        std::pair<atcoder::static_modint<mod>, atcoder::static_modint<mod>> f_coeff(unsigned i) const {
            return {atcoder::static_modint<mod>::raw(f_coeff_real[i]), atcoder::static_modint<mod>::raw(f_coeff_imag[i])};
        }
    } compile_time_precalc{};
} // namespace precalc

int main() {
    using namespace std;
    using modint = atcoder::static_modint<precalc::mod>;

    // chirp Z 変換の計算
    const auto chirp_z_convolution{
        [](auto&& a_real, auto&& a_imag, auto&& b_real, auto&& b_imag) {
            assert(a_real.size() == a_imag.size() && b_real.size() == b_imag.size());
            const auto N{b_real.size()};
            const auto length{bit_ceil(N)};
            a_real.resize(length);
            a_imag.resize(length);
            b_real.resize(length);
            b_imag.resize(length);

            atcoder::internal::butterfly(a_real);
            atcoder::internal::butterfly(a_imag);
            atcoder::internal::butterfly(b_real);
            atcoder::internal::butterfly(b_imag);

            vector<modint> result_real(length), result_imag(length);
            for (unsigned i{}; i < length; ++i) {
                result_real[i] = a_real[i] * b_real[i] + 5 * a_imag[i] * b_imag[i];
                result_imag[i] = a_real[i] * b_imag[i] + a_imag[i] * b_real[i];
            }
            atcoder::internal::butterfly_inv(result_real);
            atcoder::internal::butterfly_inv(result_imag);
            const auto iz{modint(length).inv()};
            for (unsigned i{}; i < length; ++i) {
                result_real[i] *= iz;
                result_imag[i] *= iz;
            }
            return make_pair(result_real, result_imag);
        }
    };

    unsigned T;
    cin >> T;
    for (unsigned _{}; _ < T; ++_)
        cout << [&chirp_z_convolution]() -> unsigned {
            unsigned long N;
            modint x, y;
            cin >> N >> x >> y;

            // 周期部分の答えを求める
            // x^2 + xy - y^2 が平方剰余なら、998244354 回以内に 0 が出現する
            const modint det{x * x + x * y - y * y};
            const auto is_quadratic_residue{det.pow((precalc::mod - 1) / 2) == 1};
            if (is_quadratic_residue && 2 * N >= precalc::period)
                return 0;

            // 平方非剰余なら、一周期の値は -16/25 det^2
            modint ans{(det * det * 878455030).pow(N / precalc::period)}, imag{};
            N %= precalc::period;

            // ブロックからはみ出た部分を計算する O(sub_period) 時間
            const auto remainder{N % precalc::sub_period};
            N /= precalc::sub_period;
            for (unsigned i{}; i < remainder; ++i) {
                ans *= x;
                swap(x, y);
                y += x;
            }
            if (ans == 0)
                return 0;

            if (N) {
                if (x == 0 || y == 0)
                    return 0;

                // ブロックの計算
                // a ((1 + √5)/2)^n + b ((-1 + √5)/2)^n = a_n となるような a, b を求める
                const modint a{x * modint::raw(499122177)}, b{(2 * y - x) * modint::raw(299473306)};
                const auto& c1{pair{a, b} / pair{a, -b}};

                // 係数列を求める
                // 1. c_n a^n D^{-t_n}
                vector<modint> seq1_pow_real(precalc::sub_period + 1), seq1_pow_imag(precalc::sub_period + 1);
                {
                    pair<modint, modint> state{1, 0};
                    for (unsigned i{}; i <= precalc::sub_period; ++i) {
                        tie(seq1_pow_real[i], seq1_pow_imag[i]) = state * precalc::compile_time_precalc.f_coeff(i) * precalc::compile_time_precalc.triangular_pow_inv(i);
                        state = state * c1;
                    }
                }
                // を反転したもの
                ranges::reverse(seq1_pow_real);
                ranges::reverse(seq1_pow_imag);

                // 2. D^{t_n}
                vector<modint> seq2_pow_real(precalc::compile_time_precalc.triangular_pow_real.begin(), precalc::compile_time_precalc.triangular_pow_real.begin() + precalc::sub_period + N);
                vector<modint> seq2_pow_imag(precalc::compile_time_precalc.triangular_pow_imag.begin(), precalc::compile_time_precalc.triangular_pow_imag.begin() + precalc::sub_period + N);

                // 0, 1, ..., N-1 次の係数の総積をかける
                const auto& [result_real, result_imag] = chirp_z_convolution(move(seq1_pow_real), move(seq1_pow_imag), move(seq2_pow_real), move(seq2_pow_imag));
                pair<modint, modint> state{1, 0};
                for (unsigned i{}; i < N; ++i)
                    tie(ans, imag) = pair{ans, imag} * pair{result_real[i + precalc::sub_period], result_imag[i + precalc::sub_period]};

                const auto pow{
                    [](auto a, const auto& b, unsigned long n) {
                        auto r{b};
                        while (n) {
                            if (n & 1)
                                r = r * a;
                            a = a * a;
                            n >>= 1;
                        }
                        return r;
                    }
                };

                // 残った係数をかける
                ans = pow(pair{modint::raw(998244354 / 2), modint::raw(998244352 / 2)}, pow(pair{a, -b}, pair{ans, imag}, N * precalc::sub_period), (N * precalc::sub_period * (N * precalc::sub_period - 1) / 2 + N * (N + 1) * (N - 1) / 3 * precalc::sub_period) % precalc::period).first;
            }
            return ans.val();
        }() << endl;
    return 0;
}

投稿日時:
最終更新: