Submission #70662114


Source Code Expand

#include <bits/stdc++.h>
using namespace std;

namespace std {

template <int D, typename T>
struct Vec : public vector<Vec<D - 1, T>> {
    static_assert(D >= 1, "Dimension must be positive");
    template <typename... Args>
    Vec(int n = 0, Args... args) : vector<Vec<D - 1, T>>(n, Vec<D - 1, T>(args...)) {}
};

template <typename T>
struct Vec<1, T> : public vector<T> {
    Vec(int n = 0, T val = T()) : std::vector<T>(n, val) {}
};

/* Example
    Vec<4, int64_t> f(n, k, 2, 2); // = f[n][k][2][2];
    Vec<2, int> adj(n); // graph
*/

template <class Fun>
class y_combinator_result {
    Fun fun_;

   public:
    template <class T>
    explicit y_combinator_result(T&& fun) : fun_(std::forward<T>(fun)) {}

    template <class... Args>
    decltype(auto) operator()(Args&&... args) {
        return fun_(std::ref(*this), std::forward<Args>(args)...);
    }
};

template <class Fun>
decltype(auto) y_combinator(Fun&& fun) {
    return y_combinator_result<std::decay_t<Fun>>(std::forward<Fun>(fun));
}

/* Example
    auto fun = y_combinator([&](auto self, int x) -> void {
        self(x + 1);
    });
*/

}  // namespace std

template <typename T>
T inverse(T a, T m) {
    T u = 0, v = 1;
    while (a != 0) {
        T t = m / a;
        m -= t * a;
        swap(a, m);
        u -= t * v;
        swap(u, v);
    }
    assert(m == 1);
    return u;
}

template <typename T>
class Modular {
   public:
    using Type = typename decay<decltype(T::value)>::type;

    constexpr Modular() : value() {}
    template <typename U>
    Modular(const U& x) {
        value = normalize(x);
    }

    template <typename U>
    static Type normalize(const U& x) {
        Type v;
        if (-mod() <= x && x < mod())
            v = static_cast<Type>(x);
        else
            v = static_cast<Type>(x % mod());
        if (v < 0) v += mod();
        return v;
    }

    const Type& operator()() const { return value; }
    template <typename U>
    explicit operator U() const { return static_cast<U>(value); }
    constexpr static Type mod() { return T::value; }

    Modular& operator+=(const Modular& other) {
        if ((value += other.value) >= mod()) value -= mod();
        return *this;
    }
    Modular& operator-=(const Modular& other) {
        if ((value -= other.value) < 0) value += mod();
        return *this;
    }
    template <typename U>
    Modular& operator+=(const U& other) { return *this += Modular(other); }
    template <typename U>
    Modular& operator-=(const U& other) { return *this -= Modular(other); }
    Modular& operator++() { return *this += 1; }
    Modular& operator--() { return *this -= 1; }
    Modular operator++(int) {
        Modular result(*this);
        *this += 1;
        return result;
    }
    Modular operator--(int) {
        Modular result(*this);
        *this -= 1;
        return result;
    }
    Modular operator-() const { return Modular(-value); }

    template <typename U = T>
    typename enable_if<is_same<typename Modular<U>::Type, int>::value, Modular>::type& operator*=(const Modular& rhs) {
        value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
        return *this;
    }
    template <typename U = T>
    typename enable_if<is_same<typename Modular<U>::Type, long long>::value, Modular>::type& operator*=(const Modular& rhs) {
        long long q = static_cast<long long>(static_cast<long double>(value) * rhs.value / mod());
        value = normalize(value * rhs.value - q * mod());
        return *this;
    }
    template <typename U = T>
    typename enable_if<!is_integral<typename Modular<U>::Type>::value, Modular>::type& operator*=(const Modular& rhs) {
        value = normalize(value * rhs.value);
        return *this;
    }

    Modular& operator/=(const Modular& other) { return *this *= Modular(inverse(other.value, mod())); }

    friend const Type& abs(const Modular& x) { return x.value; }

    template <typename U>
    friend bool operator==(const Modular<U>& lhs, const Modular<U>& rhs);

    template <typename U>
    friend bool operator<(const Modular<U>& lhs, const Modular<U>& rhs);

    template <typename V, typename U>
    friend V& operator>>(V& stream, Modular<U>& number);

   private:
    Type value;
};

template <typename T>
bool operator==(const Modular<T>& lhs, const Modular<T>& rhs) { return lhs.value == rhs.value; }
template <typename T, typename U>
bool operator==(const Modular<T>& lhs, U rhs) { return lhs == Modular<T>(rhs); }
template <typename T, typename U>
bool operator==(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) == rhs; }

template <typename T>
bool operator!=(const Modular<T>& lhs, const Modular<T>& rhs) { return !(lhs == rhs); }
template <typename T, typename U>
bool operator!=(const Modular<T>& lhs, U rhs) { return !(lhs == rhs); }
template <typename T, typename U>
bool operator!=(U lhs, const Modular<T>& rhs) { return !(lhs == rhs); }

template <typename T>
bool operator<(const Modular<T>& lhs, const Modular<T>& rhs) { return lhs.value < rhs.value; }

template <typename T>
Modular<T> operator+(const Modular<T>& lhs, const Modular<T>& rhs) { return Modular<T>(lhs) += rhs; }
template <typename T, typename U>
Modular<T> operator+(const Modular<T>& lhs, U rhs) { return Modular<T>(lhs) += rhs; }
template <typename T, typename U>
Modular<T> operator+(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) += rhs; }

template <typename T>
Modular<T> operator-(const Modular<T>& lhs, const Modular<T>& rhs) { return Modular<T>(lhs) -= rhs; }
template <typename T, typename U>
Modular<T> operator-(const Modular<T>& lhs, U rhs) { return Modular<T>(lhs) -= rhs; }
template <typename T, typename U>
Modular<T> operator-(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) -= rhs; }

template <typename T>
Modular<T> operator*(const Modular<T>& lhs, const Modular<T>& rhs) { return Modular<T>(lhs) *= rhs; }
template <typename T, typename U>
Modular<T> operator*(const Modular<T>& lhs, U rhs) { return Modular<T>(lhs) *= rhs; }
template <typename T, typename U>
Modular<T> operator*(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) *= rhs; }

template <typename T>
Modular<T> operator/(const Modular<T>& lhs, const Modular<T>& rhs) { return Modular<T>(lhs) /= rhs; }
template <typename T, typename U>
Modular<T> operator/(const Modular<T>& lhs, U rhs) { return Modular<T>(lhs) /= rhs; }
template <typename T, typename U>
Modular<T> operator/(U lhs, const Modular<T>& rhs) { return Modular<T>(lhs) /= rhs; }

template <typename T, typename U>
Modular<T> power(const Modular<T>& a, const U& b) {
    assert(b >= 0);
    Modular<T> x = a, res = 1;
    U p = b;
    while (p > 0) {
        if (p & 1) res *= x;
        x *= x;
        p >>= 1;
    }
    return res;
}

template <typename T>
bool IsZero(const Modular<T>& number) {
    return number() == 0;
}

template <typename T>
string to_string(const Modular<T>& number) {
    return to_string(number());
}

// U == std::ostream? but done this way because of fastoutput
template <typename U, typename T>
U& operator<<(U& stream, const Modular<T>& number) {
    return stream << number();
}

// U == std::istream? but done this way because of fastinput
template <typename U, typename T>
U& operator>>(U& stream, Modular<T>& number) {
    typename common_type<typename Modular<T>::Type, long long>::type x;
    stream >> x;
    number.value = Modular<T>::normalize(x);
    return stream;
}

/*
using ModType = int;

struct VarMod { static ModType value; };
ModType VarMod::value;
ModType& md = VarMod::value;
using Mint = Modular<VarMod>;
*/

constexpr int md = 998244353;
using Mint = Modular<std::integral_constant<decay<decltype(md)>::type, md>>;

vector<Mint> fact(1, 1);
vector<Mint> inv_fact(1, 1);

Mint C(int n, int k) {
    if (k < 0 || k > n) {
        return 0;
    }
    while ((int)fact.size() < n + 1) {
        fact.push_back(fact.back() * (int)fact.size());
        inv_fact.push_back(1 / fact.back());
    }
    return fact[n] * inv_fact[k] * inv_fact[n - k];
}

template <typename T>
class NTT {
   public:
    using Type = typename decay<decltype(T::value)>::type;

    static Type md;
    static Modular<T> root;
    static int base;
    static int max_base;
    static vector<Modular<T>> roots;
    static vector<int> rev;

    static void clear() {
        root = 0;
        base = 0;
        max_base = 0;
        roots.clear();
        rev.clear();
    }

    static void init() {
        md = T::value;
        assert(md >= 3 && md % 2 == 1);
        auto tmp = md - 1;
        max_base = 0;
        while (tmp % 2 == 0) {
            tmp /= 2;
            max_base++;
        }
        root = 2;
        while (power(root, (md - 1) >> 1) == 1) {
            root++;
        }
        assert(power(root, md - 1) == 1);
        root = power(root, (md - 1) >> max_base);
        base = 1;
        rev = {0, 1};
        roots = {0, 1};
    }

    static void ensure_base(int nbase) {
        if (md != T::value) {
            clear();
        }
        if (roots.empty()) {
            init();
        }
        if (nbase <= base) {
            return;
        }
        assert(nbase <= max_base);
        rev.resize(1 << nbase);
        for (int i = 0; i < (1 << nbase); i++) {
            rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
        }
        roots.resize(1 << nbase);
        while (base < nbase) {
            Modular<T> z = power(root, 1 << (max_base - 1 - base));
            for (int i = 1 << (base - 1); i < (1 << base); i++) {
                roots[i << 1] = roots[i];
                roots[(i << 1) + 1] = roots[i] * z;
            }
            base++;
        }
    }

    static void fft(vector<Modular<T>>& a) {
        int n = (int)a.size();
        assert((n & (n - 1)) == 0);
        int zeros = __builtin_ctz(n);
        ensure_base(zeros);
        int shift = base - zeros;
        for (int i = 0; i < n; i++) {
            if (i < (rev[i] >> shift)) {
                swap(a[i], a[rev[i] >> shift]);
            }
        }
        for (int k = 1; k < n; k <<= 1) {
            for (int i = 0; i < n; i += 2 * k) {
                for (int j = 0; j < k; j++) {
                    Modular<T> x = a[i + j];
                    Modular<T> y = a[i + j + k] * roots[j + k];
                    a[i + j] = x + y;
                    a[i + j + k] = x - y;
                }
            }
        }
    }

    static vector<Modular<T>> multiply(vector<Modular<T>> a, vector<Modular<T>> b) {
        if (a.empty() || b.empty()) {
            return {};
        }
        int eq = (a == b);
        int need = (int)a.size() + (int)b.size() - 1;
        int nbase = 0;
        while ((1 << nbase) < need) nbase++;
        ensure_base(nbase);
        int sz = 1 << nbase;
        a.resize(sz);
        b.resize(sz);
        fft(a);
        if (eq)
            b = a;
        else
            fft(b);
        Modular<T> inv_sz = 1 / static_cast<Modular<T>>(sz);
        for (int i = 0; i < sz; i++) {
            a[i] *= b[i] * inv_sz;
        }
        reverse(a.begin() + 1, a.end());
        fft(a);
        a.resize(need);
        return a;
    }
};

template <typename T>
typename NTT<T>::Type NTT<T>::md;
template <typename T>
Modular<T> NTT<T>::root;
template <typename T>
int NTT<T>::base;
template <typename T>
int NTT<T>::max_base;
template <typename T>
vector<Modular<T>> NTT<T>::roots;
template <typename T>
vector<int> NTT<T>::rev;

template <typename T>
vector<Modular<T>> inverse(const vector<Modular<T>>& a) {
    assert(!a.empty());
    int n = (int)a.size();
    vector<Modular<T>> b = {1 / a[0]};
    while ((int)b.size() < n) {
        vector<Modular<T>> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
        x.resize(b.size() << 1);
        b.resize(b.size() << 1);
        vector<Modular<T>> c = b;
        NTT<T>::fft(c);
        NTT<T>::fft(x);
        Modular<T> inv = 1 / static_cast<Modular<T>>((int)x.size());
        for (int i = 0; i < (int)x.size(); i++) {
            x[i] *= c[i] * inv;
        }
        reverse(x.begin() + 1, x.end());
        NTT<T>::fft(x);
        rotate(x.begin(), x.begin() + (x.size() >> 1), x.end());
        fill(x.begin() + (x.size() >> 1), x.end(), 0);
        NTT<T>::fft(x);
        for (int i = 0; i < (int)x.size(); i++) {
            x[i] *= c[i] * inv;
        }
        reverse(x.begin() + 1, x.end());
        NTT<T>::fft(x);
        for (int i = 0; i < ((int)x.size() >> 1); i++) {
            b[i + ((int)x.size() >> 1)] = -x[i];
        }
    }
    b.resize(n);
    return b;
}

template <typename T>
vector<Modular<T>> inverse_old(vector<Modular<T>> a) {
    assert(!a.empty());
    int n = (int)a.size();
    if (n == 1) {
        return {1 / a[0]};
    }
    int m = (n + 1) >> 1;
    vector<Modular<T>> b = inverse(vector<Modular<T>>(a.begin(), a.begin() + m));
    int need = n << 1;
    int nbase = 0;
    while ((1 << nbase) < need) {
        ++nbase;
    }
    NTT<T>::ensure_base(nbase);
    int size = 1 << nbase;
    a.resize(size);
    b.resize(size);
    NTT<T>::fft(a);
    NTT<T>::fft(b);
    Modular<T> inv = 1 / static_cast<Modular<T>>(size);
    for (int i = 0; i < size; ++i) {
        a[i] = (2 - a[i] * b[i]) * b[i] * inv;
    }
    reverse(a.begin() + 1, a.end());
    NTT<T>::fft(a);
    a.resize(n);
    return a;
}

template <typename T>
vector<Modular<T>> operator*(const vector<Modular<T>>& a, const vector<Modular<T>>& b) {
    if (a.empty() || b.empty()) {
        return {};
    }
    if (min(a.size(), b.size()) < 150) {
        vector<Modular<T>> c(a.size() + b.size() - 1, 0);
        for (int i = 0; i < (int)a.size(); i++) {
            for (int j = 0; j < (int)b.size(); j++) {
                c[i + j] += a[i] * b[j];
            }
        }
        return c;
    }
    return NTT<T>::multiply(a, b);
}

template <typename T>
vector<Modular<T>>& operator*=(vector<Modular<T>>& a, const vector<Modular<T>>& b) {
    return a = a * b;
}

template <typename T>
vector<Modular<T>> operator^(vector<Modular<T>> a, int64_t b) {
    vector<Modular<T>> res(1, 1);
    for (; b; b >>= 1, a *= a) {
        if (b & 1) res *= a;
    }
    return res;
}

template <typename T>
vector<Modular<T>>& operator^=(vector<Modular<T>>& a, int64_t b) {
    return a = a ^ b;
}

int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int n, m, s;
    cin >> n >> m >> s;
    vector<Mint> a(m + 1);
    for (int i = 0; i <= m; i++) a[i] = 1;
    vector<Mint> res(1, 1);
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        if (res.size() > s + 1) {
            res.resize(s + 1);
        }
        a *= a;
        if (a.size() > s + 1) {
            a.resize(s + 1);
        }
        n >>= 1;
    }
    cout << res[s] << "\n";
}

Submission Info

Submission Time
Task C - Sequence
User lmqzzz
Language C++ 20 (gcc 12.2)
Score 3
Code Size 15280 Byte
Status AC
Exec Time 885 ms
Memory 18956 KiB

Compile Error

Main.cpp: In function ‘int32_t main()’:
Main.cpp:503:24: warning: comparison of integer expressions of different signedness: ‘std::vector<Modular<std::integral_constant<int, 998244353> > >::size_type’ {aka ‘long unsigned int’} and ‘int’ [-Wsign-compare]
  503 |         if (res.size() > s + 1) {
      |             ~~~~~~~~~~~^~~~~~~
Main.cpp:507:22: warning: comparison of integer expressions of different signedness: ‘std::vector<Modular<std::integral_constant<int, 998244353> > >::size_type’ {aka ‘long unsigned int’} and ‘int’ [-Wsign-compare]
  507 |         if (a.size() > s + 1) {
      |             ~~~~~~~~~^~~~~~~

Judge Result

Set Name Sample All
Score / Max Score 0 / 0 3 / 3
Status
AC × 2
AC × 9
Set Name Test Cases
Sample 00_sample_00.txt, 00_sample_01.txt
All 00_sample_00.txt, 00_sample_01.txt, 01_random_00.txt, 01_random_01.txt, 02_m_small_00.txt, 02_m_small_01.txt, 02_m_small_02.txt, 03_max_00.txt, 04_min_00.txt
Case Name Status Exec Time Memory
00_sample_00.txt AC 1 ms 3600 KiB
00_sample_01.txt AC 152 ms 9828 KiB
01_random_00.txt AC 372 ms 10616 KiB
01_random_01.txt AC 885 ms 18956 KiB
02_m_small_00.txt AC 94 ms 15784 KiB
02_m_small_01.txt AC 156 ms 17784 KiB
02_m_small_02.txt AC 187 ms 16340 KiB
03_max_00.txt AC 749 ms 17820 KiB
04_min_00.txt AC 1 ms 3608 KiB