Submission #25075718


Source Code Expand

#define _SILENCE_CXX17_C_HEADER_DEPRECATION_WARNING
#define _CRT_SECURE_NO_WARNINGS
#include <bits/stdc++.h>
using namespace std;
#define db double
#ifndef ONLINE_JUDGE
inline int __builtin_clz(int v) { // 返回前导0的个数
    return __lzcnt(v);
}
inline int __builtin_ctz(int v) { // 返回末尾0的个数
    if (v == 0) {
        return 0;
    }
    __asm {
        bsf eax, dword ptr[v];
    }
}
inline int __builtin_popcount(int v) { // 返回二进制中1的个数
    return __popcnt(v);
}
#endif
struct Complex {
    db real, imag;
    Complex(db x = 0, db y = 0) :real(x), imag(y) {}
    Complex& operator+=(const Complex& rhs) {
        real += rhs.real; imag += rhs.imag;
        return *this;
    }
    Complex& operator-=(const Complex& rhs) {
        real -= rhs.real; imag -= rhs.imag;
        return *this;
    }
    Complex& operator*=(const Complex& rhs) {
        db t_real = real * rhs.real - imag * rhs.imag;
        imag = real * rhs.imag + imag * rhs.real;
        real = t_real;
        return *this;
    }
    Complex& operator/=(double x) {
        real /= x, imag /= x;
        return *this;
    }
    friend Complex operator + (const Complex& a, const Complex& b) { return Complex(a) += b; }
    friend Complex operator - (const Complex& a, const Complex& b) { return Complex(a) -= b; }
    friend Complex operator * (const Complex& a, const Complex& b) { return Complex(a) *= b; }
    friend Complex operator / (const Complex& a, const db& b) { return Complex(a) /= b; }
    Complex power(long long p) const {
        assert(p >= 0);
        Complex a = *this, res = { 1,0 };
        while (p > 0) {
            if (p & 1) res = res * a;
            a = a * a;
            p >>= 1;
        }
        return res;
    }
    static long long val(double x) { return x < 0 ? x - 0.5 : x + 0.5; }
    inline long long Real() const { return val(real); }
    inline long long Imag() const { return val(imag); }
    Complex conj()const { return Complex(real, -imag); }
    explicit operator int()const { return Real(); }
    friend ostream& operator<<(ostream& stream, const Complex& m) {
        return stream << complex<db>(m.real, m.imag);
    }
};
constexpr int MOD = 998244353;
constexpr int Phi_MOD = 998244352;
inline int exgcd(int a, int md = MOD) {
    a %= md;
    if (a < 0) a += md;
    int b = md, u = 0, v = 1;
    while (a) {
        int t = b / a;
        b -= t * a; swap(a, b);
        u -= t * v; swap(u, v);
    }
    assert(b == 1);
    if (u < 0) u += md;
    return u;
}
inline int add(int a, int b) { return a + b >= MOD ? a + b - MOD : a + b; }
inline int sub(int a, int b) { return a - b < 0 ? a - b + MOD : a - b; }
inline int mul(int a, int b) { return 1LL * a * b % MOD; }
inline int powmod(int a, long long b) {
    int res = 1;
    while (b > 0) {
        if (b & 1) res = mul(res, a);
        a = mul(a, a);
        b >>= 1;
    }
    return res;
}

vector<int> inv, fac, ifac;
void prepare_factorials(int maximum) {
    inv.assign(maximum + 1, 1);
    // Make sure MOD is prime, which is necessary for the inverse algorithm below.
    for (int p = 2; p * p <= MOD; p++)
        assert(MOD % p != 0);
    for (int i = 2; i <= maximum; i++)
        inv[i] = mul(inv[MOD % i], (MOD - MOD / i));

    fac.resize(maximum + 1);
    ifac.resize(maximum + 1);
    fac[0] = ifac[0] = 1;

    for (int i = 1; i <= maximum; i++) {
        fac[i] = mul(i, fac[i - 1]);
        ifac[i] = mul(inv[i], ifac[i - 1]);
    }
}
namespace FFT {
    vector<Complex> roots = { Complex(0, 0), Complex(1, 0) };
    vector<int> bit_reverse;
    int max_size = 1 << 20;
    const long double pi = acosl(-1.0l);
    constexpr int FFT_CUTOFF = 150;
    inline bool is_power_of_two(int n) { return (n & (n - 1)) == 0; }
    inline int round_up_power_two(int n) {
        assert(n > 0);
        while (n & (n - 1)) {
            n = (n | (n - 1)) + 1;
        }
        return n;
    }
    // Given n (a power of two), finds k such that n == 1 << k.
    inline int get_length(int n) {
        assert(is_power_of_two(n));
        return __builtin_ctz(n);
    }
    // Rearranges the indices to be sorted by lowest bit first, then second lowest, etc., rather than highest bit first.
    // This makes even-odd div-conquer much easier.
    void bit_reorder(int n, vector<Complex>& values) {
        if ((int)bit_reverse.size() != n) {
            bit_reverse.assign(n, 0);
            int length = get_length(n);
            for (int i = 0; i < n; i++) {
                bit_reverse[i] = (bit_reverse[i >> 1] >> 1) + ((i & 1) << (length - 1));
            }
        }
        for (int i = 0; i < n; i++) {
            if (i < bit_reverse[i]) {
                swap(values[i], values[bit_reverse[i]]);
            }
        }
    }
    void prepare_roots(int n) {
        assert(n <= max_size);
        if ((int)roots.size() >= n)
            return;
        int length = get_length(roots.size());
        roots.resize(n);
        // The roots array is set up such that for a given power of two n >= 2, roots[n / 2] through roots[n - 1] are
        // the first half of the n-th primitive roots of MOD.
        while (1 << length < n) {
            for (int i = 1 << (length - 1); i < 1 << length; i++) {
                roots[2 * i] = roots[i];
                long double angle = pi * (2 * i + 1) / (1 << length);
                roots[2 * i + 1] = Complex(-cos(angle), -sin(angle));
            }
            length++;
        }
    }
    void fft_iterative(int N, vector<Complex>& values) {
        assert(is_power_of_two(N));
        prepare_roots(N);
        bit_reorder(N, values);
        for (int n = 1; n < N; n *= 2) {
            for (int start = 0; start < N; start += 2 * n) {
                for (int i = 0; i < n; i++) {
                    Complex& even = values[start + i];
                    Complex odd = values[start + n + i] * roots[n + i];
                    values[start + n + i] = even - odd;
                    values[start + i] = even + odd;
                }
            }
        }
    }
    vector<long long> multiply(vector<int> a, vector<int> b) { // 普通FFT
        int n = a.size();
        int m = b.size();
        if (min(n, m) < FFT_CUTOFF) {
            vector<long long> res(n + m - 1);
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    res[i + j] += 1LL * a[i] * b[j];
                }
            }
            return res;
        }
        int N = round_up_power_two(n + m - 1);
        vector<Complex> tmp(N);
        for (int i = 0; i < a.size(); i++) tmp[i].real = a[i];
        for (int i = 0; i < b.size(); i++) tmp[i].imag = b[i];
        fft_iterative(N, tmp);
        for (int i = 0; i < N; i++) tmp[i] = tmp[i] * tmp[i];
        reverse(tmp.begin() + 1, tmp.end());
        fft_iterative(N, tmp);
        vector<long long> res(n + m - 1);
        for (int i = 0; i < res.size(); i++) {
            res[i] = tmp[i].imag / 2 / N + 0.5;
        }
        return res;
    }
    vector<int> mod_multiply(vector<int> a, vector<int> b, int lim = max_size) { // 任意模数FFT
        int n = a.size();
        int m = b.size();
        if (min(n, m) < FFT_CUTOFF) {
            vector<int> res(n + m - 1);
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    res[i + j] += 1LL * a[i] * b[j] % MOD;
                    res[i + j] %= MOD;
                }
            }
            return res;
        }
        int N = round_up_power_two(n + m - 1);
        N = min(N, lim);
        vector<Complex> P(N);
        vector<Complex> Q(N);
        for (int i = 0; i < n; i++) {
            P[i] = Complex(a[i] >> 15, a[i] & 0x7fff);
        }
        for (int i = 0; i < m; i++) {
            Q[i] = Complex(b[i] >> 15, b[i] & 0x7fff);
        }
        fft_iterative(N, P);
        fft_iterative(N, Q);
        vector<Complex>A(N), B(N), C(N), D(N);
        for (int i = 0; i < N; i++) {
            Complex P2 = P[(N - i) & (N - 1)].conj();
            A[i] = (P2 + P[i]) * Complex(0.5, 0),
                B[i] = (P2 - P[i]) * Complex(0, 0.5);
            Complex Q2 = Q[(N - i) & (N - 1)].conj();
            C[i] = (Q2 + Q[i]) * Complex(0.5, 0),
                D[i] = (Q2 - Q[i]) * Complex(0, 0.5);
        }
        for (int i = 0; i < N; i++) {
            P[i] = (A[i] * C[i]) + (B[i] * D[i]) * Complex(0, 1),
                Q[i] = (A[i] * D[i]) + (B[i] * C[i]) * Complex(0, 1);
        }
        reverse(P.begin() + 1, P.end());
        reverse(Q.begin() + 1, Q.end());
        fft_iterative(N, P);
        fft_iterative(N, Q);
        for (int i = 0; i < N; i++) {
            P[i] /= N, Q[i] /= N;
        }
        int size = min(n + m - 1, lim);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            long long ac = P[i].Real() % MOD, bd = P[i].Imag() % MOD,
                ad = Q[i].Real() % MOD, bc = Q[i].Imag() % MOD;
            res[i] = ((ac << 30) + bd + ((ad + bc) << 15)) % MOD;
        }
        return res.resize(n + m - 1), res;
    }
    vector<int> mod_inv(vector<int> a) { // 多项式逆
        int n = a.size();
        int N = round_up_power_two(a.size());
        a.resize(N * 2);
        vector<int> res(1);
        res[0] = exgcd(a[0]);
        for (int i = 2; i <= N; i <<= 1) {
            vector<int> tmp(a.begin(), a.begin() + i);
            int n = (i << 1);
            tmp = mod_multiply(tmp, mod_multiply(res, res, n), n);
            res.resize(i);
            for (int j = 0; j < i; j++) {
                res[j] = add(res[j], sub(res[j], tmp[j]));
            }
        }
        res.resize(n);
        return res;
    }
    vector<int> integral(vector<int> a) { // 多项式积分
        assert(a.size() <= inv.size());
        a.push_back(0);
        for (int i = (int)a.size() - 1; i >= 1; i--) {
            a[i] = mul(a[i - 1], inv[i]);
        }
        return a;
    }
    vector<int> differential(vector<int> a) { // 多项式求导
        for (int i = 0; i < (int)a.size() - 1; i++) {
            a[i] = mul(i + 1, a[i + 1]);
        }
        a.pop_back();
        return a;
    }
    vector<int> ln(vector<int> a) { // 多项式对数函数
        assert((int)a[0] == 1);
        auto b = mod_multiply(differential(a), mod_inv(a));
        b = integral(b);
        b[0] = 0;
        return b;
    }
    vector<int> exp(vector<int> a) { // 多项式指数函数
        int N = round_up_power_two(a.size());
        int n = a.size();
        a.resize(N * 2);
        vector<int> res{ 1 };
        for (int i = 2; i <= N; i <<= 1) {
            auto tmp = res;
            tmp.resize(i);
            tmp = ln(tmp);
            for (int j = 0; j < i; j++) {
                tmp[j] = sub(a[j], tmp[j]);
            }
            tmp[0] = add(tmp[0], 1);
            res.resize(i);
            res = mod_multiply(res, tmp, i << 1);
            fill(res.begin() + i, res.end(), 0);
        }
        res.resize(n);
        return res;
    }
    // Multiplies many polynomials whose total degree is n in O(n log^2 n).
    vector<int> mod_multiply_all(const vector<vector<int>>& polynomials) {
        if (polynomials.empty())
            return { 1 };
        struct compare_size {
            bool operator()(const vector<int>& x, const vector<int>& y) {
                return x.size() > y.size();
            }
        };
        priority_queue<vector<int>, vector<vector<int>>, compare_size> pq(polynomials.begin(), polynomials.end());
        while (pq.size() > 1) {
            vector<int> a = pq.top(); pq.pop();
            vector<int> b = pq.top(); pq.pop();
            pq.push(mod_multiply(a, b));
        }
        return pq.top();
    }
    tuple<int, int, bool> power_reduction(string s, int n) { // 多项式快速幂预处理
        int p = 0, q = 0; bool zero = false;
        for (int i = 0; i < s.length(); i++) {
            p = mul(p, 10);
            p = add(p, s[i] - '0');
            q = 1LL * q * 10 % Phi_MOD; // Phi_MOD 是MOD的欧拉函数值
            q = (q + s[i] - '0');
            if (q >= Phi_MOD) q -= Phi_MOD;
            if (q >= (int)n) zero = true;
        }
        return { p,q,zero };
    }
    vector<int> power(vector<int> a, string s) { // 多项式快速幂 a^s O(nlogn)
        int n = a.size();
        auto [p, q, zero] = power_reduction(s, (int)a.size()); // 不需要降幂的话可以省去这部分
        if (a[0] == 1) {
            auto res = ln(a);
            while ((int)res.size() > n) res.pop_back();
            for (auto& i : res) {
                i = mul(p, i);
            }
            res = exp(res);
            return res;
        } else {
            int mn = -1;
            vector<int> copy_a;
            for (int i = 0; i < (int)a.size(); i++) {
                if (a[i]) {
                    mn = i;
                    break;
                }
            }
            if ((mn == -1) || (mn && (zero || (1LL * mn * p > n)))) { // a中所有元素都是0 或 偏移过大
                return vector<int>(n, 0);
            }
            int inverse_amin = exgcd(a[mn]);
            for (int i = mn; i < n; i++) {
                copy_a.emplace_back(mul(a[i], inverse_amin));
            }
            copy_a = ln(copy_a);
            while ((int)copy_a.size() > n) copy_a.pop_back();
            for (auto& i : copy_a) {
                i = mul(i, p);
            }
            copy_a = exp(copy_a);
            vector<int> res(n, 0);
            // shift是偏移量 power_k 是a_min^q(q是扩展欧拉定理降出来的幂次)
            int shift = mn * p, power_k = powmod(a[mn], q);
            for (int i = 0; i + shift < n; i++) {
                res[i + shift] = mul(copy_a[i], power_k);
            }
            return res;
        }
    }
    vector<long long> sub_convolution(vector<int> a, vector<int> b) { // 减法卷积 只保留非负次项
        int n = b.size();
        reverse(b.begin(), b.end());
        auto res = multiply(a, b);
        return vector<long long>(res.begin() + n - 1, res.end());
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n, m, res = 0;
    cin >> n >> m;
    vector<int> a(m + 1, 0);
    vector<int> b(n);
    a[0] = 1;
    for (auto& i : b) {
        cin >> i;
    }
    for (auto i : b) {
        vector<int> c(i + 1, 0);
        c[0] = 2, c.back() = 1;
        a = FFT::mod_multiply(a, c);
        while (a.size() > m + 1) {
            a.pop_back();
        }
    }
    cout << a[m];
    return 0;
}

Submission Info

Submission Time
Task F - Knapsack for All Subsets
User st1vdy
Language C++ (GCC 9.2.1)
Score 600
Code Size 14993 Byte
Status AC
Exec Time 1541 ms
Memory 4908 KiB

Compile Error

./Main.cpp: In function ‘std::vector<long long int> FFT::multiply(std::vector<int>, std::vector<int>)’:
./Main.cpp:193:27: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<int>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
  193 |         for (int i = 0; i < a.size(); i++) tmp[i].real = a[i];
      |                         ~~^~~~~~~~~~
./Main.cpp:194:27: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<int>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
  194 |         for (int i = 0; i < b.size(); i++) tmp[i].imag = b[i];
      |                         ~~^~~~~~~~~~
./Main.cpp:200:27: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<long long int>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
  200 |         for (int i = 0; i < res.size(); i++) {
      |                         ~~^~~~~~~~~~~~
./Main.cpp: In function ‘std::tuple<int, int, bool> FFT::power_reduction(std::string, int)’:
./Main.cpp:338:27: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::__cxx11::basic_string<char>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
  338 |         for (int i = 0; i < s.length(); i++) {
      |                         ~~^~~~~~~~~~~~
./Main.cpp: In function ‘int main()’:
./Main.cpp:414:25: warning: comparison of integer expressions of different signedness: ‘std::vector<int>::size_type’ {aka ‘long unsigned int’} and ‘int’ [-Wsign-compare]
  414 |         while (a.size() > m + 1) {
      |                ~~~~~~~~~^~~~~~~
./Main.cpp:402:15: warning: unused variable ‘res’ [-Wunused-variable]
  402 |     int n, m, res = 0;
      |               ^~~

Judge Result

Set Name sample All
Score / Max Score 0 / 0 600 / 600
Status
AC × 3
AC × 27
Set Name Test Cases
sample sample01, sample02, sample03
All 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33, 34, 35, 41, 42, 43, 44, 45, 51, 52, 53, 54, sample01, sample02, sample03
Case Name Status Exec Time Memory
11 AC 12 ms 4240 KiB
12 AC 20 ms 4096 KiB
13 AC 32 ms 4136 KiB
14 AC 16 ms 4088 KiB
15 AC 9 ms 4048 KiB
21 AC 1065 ms 3752 KiB
22 AC 1065 ms 4848 KiB
23 AC 1070 ms 4788 KiB
24 AC 1373 ms 4908 KiB
25 AC 234 ms 4752 KiB
31 AC 36 ms 3528 KiB
32 AC 46 ms 3700 KiB
33 AC 43 ms 3548 KiB
34 AC 44 ms 3552 KiB
35 AC 44 ms 3672 KiB
41 AC 1541 ms 4332 KiB
42 AC 1388 ms 4304 KiB
43 AC 1518 ms 4256 KiB
44 AC 1524 ms 4328 KiB
45 AC 1528 ms 4260 KiB
51 AC 5 ms 4036 KiB
52 AC 9 ms 4536 KiB
53 AC 16 ms 3700 KiB
54 AC 2 ms 3540 KiB
sample01 AC 2 ms 3632 KiB
sample02 AC 2 ms 3500 KiB
sample03 AC 2 ms 3560 KiB