Submission #25073462
Source Code
#define _SILENCE_CXX17_C_HEADER_DEPRECATION_WARNING
#define _CRT_SECURE_NO_WARNINGS
#include <bits/stdc++.h>
using namespace std;
#define db double
// MSVC has no GCC bit-manipulation builtins, so emulate the three used in this file.
// Guarded on the compiler rather than on ONLINE_JUDGE: a local GCC/Clang build already has
// the real builtins (and cannot parse MSVC `__asm { }` blocks). clang-cl defines _MSC_VER
// but also provides the builtins, so it is excluded.
#if defined(_MSC_VER) && !defined(__clang__)
#include <intrin.h>
// Returns the number of leading zero bits of v; returns 32 for v == 0 (undefined in GCC).
// Uses BSR instead of LZCNT, which silently misbehaves on CPUs without the LZCNT extension.
inline int __builtin_clz(unsigned int v) {
    unsigned long index;
    return _BitScanReverse(&index, v) ? 31 - static_cast<int>(index) : 32;
}
// Returns the number of trailing zero bits of v; returns 0 for v == 0 (undefined in GCC).
// The intrinsic works on both x86 and x64 (inline __asm is not available on x64 MSVC).
inline int __builtin_ctz(unsigned int v) {
    unsigned long index;
    return _BitScanForward(&index, v) ? static_cast<int>(index) : 0;
}
// Returns the number of set bits of v.
inline int __builtin_popcount(unsigned int v) {
    return static_cast<int>(__popcnt(v));
}
#endif
// Minimal complex number over double, used as the element type of the FFT below.
// val/Real/Imag round a floating value to the nearest integer (halves away from zero),
// which is how FFT results are turned back into exact integer coefficients.
struct Complex {
    double real, imag;

    Complex(double re = 0, double im = 0) : real(re), imag(im) {}

    Complex& operator+=(const Complex& other) {
        real += other.real;
        imag += other.imag;
        return *this;
    }

    Complex& operator-=(const Complex& other) {
        real -= other.real;
        imag -= other.imag;
        return *this;
    }

    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i; both parts are computed from the old values,
    // so self-multiplication (z *= z) is safe.
    Complex& operator*=(const Complex& other) {
        const double re = real * other.real - imag * other.imag;
        const double im = real * other.imag + imag * other.real;
        real = re;
        imag = im;
        return *this;
    }

    // Divides both parts by a real scalar.
    Complex& operator/=(double x) {
        real /= x;
        imag /= x;
        return *this;
    }

    friend Complex operator+(const Complex& a, const Complex& b) {
        Complex sum(a);
        sum += b;
        return sum;
    }
    friend Complex operator-(const Complex& a, const Complex& b) {
        Complex diff(a);
        diff -= b;
        return diff;
    }
    friend Complex operator*(const Complex& a, const Complex& b) {
        Complex prod(a);
        prod *= b;
        return prod;
    }
    friend Complex operator/(const Complex& a, const double& b) {
        Complex quot(a);
        quot /= b;
        return quot;
    }

    // Returns this^p by binary exponentiation; p must be non-negative.
    Complex power(long long p) const {
        assert(p >= 0);
        Complex result(1, 0);
        Complex base = *this;
        for (; p > 0; p >>= 1) {
            if (p & 1) result = result * base;
            base = base * base;
        }
        return result;
    }

    // Rounds x to the nearest integer, rounding halves away from zero.
    static long long val(double x) {
        if (x < 0) return static_cast<long long>(x - 0.5);
        return static_cast<long long>(x + 0.5);
    }
    // Real part rounded to the nearest integer.
    long long Real() const { return val(real); }
    // Imaginary part rounded to the nearest integer.
    long long Imag() const { return val(imag); }
    // Complex conjugate.
    Complex conj() const { return Complex(real, -imag); }
    // Rounded real part, narrowed to int.
    explicit operator int() const { return static_cast<int>(Real()); }

    friend std::ostream& operator<<(std::ostream& stream, const Complex& m) {
        return stream << std::complex<double>(m.real, m.imag);
    }
};
// Prime modulus used by every modular helper and polynomial routine in this file.
constexpr int MOD = 998244353;
// Euler's totient of MOD; MOD is prime, so this is MOD - 1 (used to reduce exponents).
constexpr int Phi_MOD = MOD - 1;

// Returns the multiplicative inverse of a modulo md, in [0, md), computed with the
// extended Euclidean algorithm. Asserts that gcd(a, md) == 1.
inline int exgcd(int a, int md = MOD) {
    // Invariant: r_k == s_k * a (mod md) for both tracked pairs.
    int r_prev = md, s_prev = 0;
    int r_cur = a % md, s_cur = 1;
    if (r_cur < 0) r_cur += md;
    while (r_cur != 0) {
        const int q = r_prev / r_cur;
        const int r_next = r_prev - q * r_cur;
        const int s_next = s_prev - q * s_cur;
        r_prev = r_cur;
        r_cur = r_next;
        s_prev = s_cur;
        s_cur = s_next;
    }
    assert(r_prev == 1);
    return s_prev < 0 ? s_prev + md : s_prev;
}

// (a + b) mod MOD for a, b in [0, MOD).
inline int add(int a, int b) {
    const int sum = a + b;
    return sum >= MOD ? sum - MOD : sum;
}
// (a - b) mod MOD for a, b in [0, MOD).
inline int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + MOD : diff;
}
// (a * b) mod MOD, computed in 64-bit to avoid overflow.
inline int mul(int a, int b) {
    return static_cast<int>(static_cast<long long>(a) * b % MOD);
}
// a^b mod MOD by binary exponentiation; returns 1 for b <= 0.
inline int powmod(int a, long long b) {
    int result = 1;
    for (; b > 0; b >>= 1, a = mul(a, a)) {
        if (b & 1) result = mul(result, a);
    }
    return result;
}
// Modular tables filled by prepare_factorials: inv[i] = i^-1, fac[i] = i!, ifac[i] = (i!)^-1, all mod MOD.
vector<int> inv, fac, ifac;

// Fills inv, fac and ifac for every index 0..maximum (inv[0] is a placeholder 1).
// Uses the linear recurrence inv[i] = -(MOD / i) * inv[MOD % i], which is only valid for a
// prime modulus, so MOD's primality is asserted by trial division on each call.
void prepare_factorials(int maximum) {
    inv.assign(maximum + 1, 1);
    for (int d = 2; d * d <= MOD; ++d) {
        assert(MOD % d != 0);  // MOD must be prime for the inverse recurrence
    }
    fac.resize(maximum + 1);
    ifac.resize(maximum + 1);
    fac[0] = 1;
    ifac[0] = 1;
    for (int i = 1; i <= maximum; ++i) {
        // MOD % i < i, so inv[MOD % i] has already been computed.
        if (i >= 2) inv[i] = mul(inv[MOD % i], MOD - MOD / i);
        fac[i] = mul(fac[i - 1], i);
        ifac[i] = mul(ifac[i - 1], inv[i]);
    }
}
namespace FFT {
// Twiddle table. For every power of two n >= 2 prepared so far, roots[n / 2 + j] holds the
// complex root of unity exp(2*pi*i*j / n) for j in [0, n / 2). roots[0] is unused.
vector<Complex> roots = { Complex(0, 0), Complex(1, 0) };
// Bit-reversal permutation cached for the last size passed to bit_reorder.
vector<int> bit_reverse;
// Largest transform length prepare_roots accepts.
int max_size = 1 << 20;
const long double pi = acosl(-1.0l);
// When the shorter operand has fewer coefficients than this, the schoolbook product is used.
constexpr int FFT_CUTOFF = 150;

// True when n is a power of two (note: also true for n == 0).
inline bool is_power_of_two(int n) { return (n & (n - 1)) == 0; }

// Smallest power of two >= n; n must be positive.
inline int round_up_power_two(int n) {
    assert(n > 0);
    while (n & (n - 1)) {
        // Set every bit below the lowest set bit, then +1 carries past them.
        n = (n | (n - 1)) + 1;
    }
    return n;
}

// Given n (a power of two), finds k such that n == 1 << k.
inline int get_length(int n) {
    assert(is_power_of_two(n));
    return __builtin_ctz(n);
}

// Permutes values[0, n) into bit-reversed index order, which is the input order the iterative
// transform below expects. n must be a power of two.
void bit_reorder(int n, vector<Complex>& values) {
    if (n <= 1) return;  // identity permutation; also avoids a shift by length - 1 == -1
    if ((int)bit_reverse.size() != n) {
        bit_reverse.assign(n, 0);
        int length = get_length(n);
        for (int i = 0; i < n; i++) {
            bit_reverse[i] = (bit_reverse[i >> 1] >> 1) + ((i & 1) << (length - 1));
        }
    }
    for (int i = 0; i < n; i++) {
        if (i < bit_reverse[i]) {
            swap(values[i], values[bit_reverse[i]]);
        }
    }
}

// Extends `roots` to cover transform length n (a power of two, at most max_size).
void prepare_roots(int n) {
    assert(n <= max_size);
    if ((int)roots.size() >= n)
        return;
    int length = get_length(roots.size());
    roots.resize(n);
    while (1 << length < n) {
        for (int i = 1 << (length - 1); i < 1 << length; i++) {
            // Even entries of the finer level coincide with the coarser level.
            roots[2 * i] = roots[i];
            // Odd entries are computed directly; negating exp(i * angle) shifts the angle by pi,
            // which yields the required root for index 2i+1.
            long double angle = pi * (2 * i + 1) / (1 << length);
            roots[2 * i + 1] = Complex(-cos(angle), -sin(angle));
        }
        length++;
    }
}

// In-place DFT of values[0, N) with exp(+2*pi*i/N) twiddles, N a power of two.
// Callers obtain N times the inverse DFT by reversing entries 1..N-1 and transforming again.
void fft_iterative(int N, vector<Complex>& values) {
    assert(is_power_of_two(N));
    prepare_roots(N);
    bit_reorder(N, values);
    for (int n = 1; n < N; n *= 2) {
        for (int start = 0; start < N; start += 2 * n) {
            for (int i = 0; i < n; i++) {
                Complex& even = values[start + i];
                Complex odd = values[start + n + i] * roots[n + i];
                values[start + n + i] = even - odd;
                values[start + i] = even + odd;
            }
        }
    }
}

// Plain (non-modular) integer convolution of a and b; the result has length
// a.size() + b.size() - 1, or 0 when both are empty. The FFT path packs a into the real part
// and b into the imaginary part, so squaring the transform gives 2i * (a * b). Exact only while
// the coefficients stay well within double precision.
vector<long long> multiply(vector<int> a, vector<int> b) {
    int n = a.size();
    int m = b.size();
    if (n == 0 || m == 0) {
        // Same length as the schoolbook path for one empty operand; avoids size -1 for two.
        return vector<long long>(max(n + m - 1, 0), 0);
    }
    if (min(n, m) < FFT_CUTOFF) {
        vector<long long> res(n + m - 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                res[i + j] += 1LL * a[i] * b[j];
            }
        }
        return res;
    }
    int N = round_up_power_two(n + m - 1);
    vector<Complex> tmp(N);
    for (int i = 0; i < n; i++) tmp[i].real = a[i];
    for (int i = 0; i < m; i++) tmp[i].imag = b[i];
    fft_iterative(N, tmp);
    for (int i = 0; i < N; i++) tmp[i] = tmp[i] * tmp[i];
    reverse(tmp.begin() + 1, tmp.end());
    fft_iterative(N, tmp);
    vector<long long> res(n + m - 1);
    for (int i = 0; i < n + m - 1; i++) {
        // Round to nearest; unlike `x + 0.5` truncation this is also right for negative results.
        res[i] = Complex::val(tmp[i].imag / 2 / N);
    }
    return res;
}

// Product of a and b modulo MOD for an arbitrary modulus: each coefficient is split as
// hi * 2^15 + lo so every floating-point partial product stays small enough to round exactly.
// Coefficients are expected in [0, MOD). The result has length a.size() + b.size() - 1
// (0 when both are empty).
// lim caps the FFT length and must be a power of two when it is below the natural length.
// In that case the FFT path computes the cyclic convolution modulo x^lim - 1 (operands longer
// than lim are folded onto index i mod lim), and entries at index >= lim are zero.
// The schoolbook path always returns the full product.
vector<int> mod_multiply(vector<int> a, vector<int> b, int lim = max_size) {
    int n = a.size();
    int m = b.size();
    if (n == 0 || m == 0) {
        return vector<int>(max(n + m - 1, 0), 0);
    }
    if (min(n, m) < FFT_CUTOFF) {
        vector<int> res(n + m - 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                res[i + j] += 1LL * a[i] * b[j] % MOD;
                res[i + j] %= MOD;
            }
        }
        return res;
    }
    assert(lim > 0);
    int N = min(round_up_power_two(n + m - 1), lim);
    // Reduce an operand longer than the transform modulo x^N - 1, instead of writing past the
    // end of the FFT buffers.
    auto fold = [N](vector<int>& poly) {
        if ((int)poly.size() <= N) return;
        vector<int> folded(N, 0);
        for (int i = 0; i < (int)poly.size(); i++) {
            folded[i & (N - 1)] = add(folded[i & (N - 1)], poly[i]);
        }
        poly.swap(folded);
    };
    fold(a);
    fold(b);
    vector<Complex> P(N);
    vector<Complex> Q(N);
    for (int i = 0; i < (int)a.size(); i++) {
        P[i] = Complex(a[i] >> 15, a[i] & 0x7fff);
    }
    for (int i = 0; i < (int)b.size(); i++) {
        Q[i] = Complex(b[i] >> 15, b[i] & 0x7fff);
    }
    fft_iterative(N, P);
    fft_iterative(N, Q);
    vector<Complex> A(N), B(N), C(N), D(N);
    for (int i = 0; i < N; i++) {
        // Unpack the two real sequences sharing each transform via conjugate symmetry:
        // A/C are the transforms of the high halves, B/D those of the low halves.
        Complex P2 = P[(N - i) & (N - 1)].conj();
        A[i] = (P2 + P[i]) * Complex(0.5, 0);
        B[i] = (P2 - P[i]) * Complex(0, 0.5);
        Complex Q2 = Q[(N - i) & (N - 1)].conj();
        C[i] = (Q2 + Q[i]) * Complex(0.5, 0);
        D[i] = (Q2 - Q[i]) * Complex(0, 0.5);
    }
    for (int i = 0; i < N; i++) {
        // After the inverse transform: Re(P) = hi*hi, Im(P) = lo*lo, Re(Q) = hi*lo, Im(Q) = lo*hi.
        P[i] = (A[i] * C[i]) + (B[i] * D[i]) * Complex(0, 1);
        Q[i] = (A[i] * D[i]) + (B[i] * C[i]) * Complex(0, 1);
    }
    reverse(P.begin() + 1, P.end());
    reverse(Q.begin() + 1, Q.end());
    fft_iterative(N, P);
    fft_iterative(N, Q);
    for (int i = 0; i < N; i++) {
        P[i] /= N;
        Q[i] /= N;
    }
    int size = min(n + m - 1, lim);
    vector<int> res(size);
    for (int i = 0; i < size; i++) {
        long long ac = P[i].Real() % MOD, bd = P[i].Imag() % MOD,
                  ad = Q[i].Real() % MOD, bc = Q[i].Imag() % MOD;
        res[i] = ((ac << 30) + bd + ((ad + bc) << 15)) % MOD;
    }
    res.resize(n + m - 1);
    return res;
}

// Power-series inverse: returns r with a * r == 1 (mod x^{a.size()}), using the Newton step
// r <- 2r - a r^2. Requires a non-empty with a[0] invertible modulo MOD (asserted in exgcd).
vector<int> mod_inv(vector<int> a) {
    int n = a.size();
    int N = round_up_power_two(n);
    a.resize(N);
    vector<int> res{ exgcd(a[0]) };
    for (int i = 2; i <= N; i <<= 1) {
        vector<int> tmp(a.begin(), a.begin() + i);
        // Transform length 2i: tmp * res^2 has fewer than 2i terms, so nothing wraps around.
        int len = i << 1;
        tmp = mod_multiply(tmp, mod_multiply(res, res, len), len);
        res.resize(i);
        for (int j = 0; j < i; j++) {
            res[j] = add(res[j], sub(res[j], tmp[j]));
        }
    }
    res.resize(n);
    return res;
}

// Formal integral with zero constant term: returns c of length a.size() + 1 with
// c[0] = 0 and c[i] = a[i - 1] / i. Needs inv[a.size()], i.e. prepare_factorials(k) for some
// k >= a.size().
vector<int> integral(vector<int> a) {
    assert(a.empty() || a.size() < inv.size());
    a.push_back(0);
    for (int i = (int)a.size() - 1; i >= 1; i--) {
        a[i] = mul(a[i - 1], inv[i]);
    }
    a[0] = 0;
    return a;
}

// Formal derivative: returns c of length a.size() - 1 with c[i] = (i + 1) * a[i + 1]
// (empty for empty input).
vector<int> differential(vector<int> a) {
    if (a.empty()) return a;
    for (int i = 0; i + 1 < (int)a.size(); i++) {
        a[i] = mul(i + 1, a[i + 1]);
    }
    a.pop_back();
    return a;
}

// Power-series logarithm: ln(a) = integral(a' / a). Requires a[0] == 1. The first a.size()
// entries of the result are meaningful; the returned vector may be longer (callers truncate).
vector<int> ln(vector<int> a) {
    assert(!a.empty() && a[0] == 1);
    auto b = mod_multiply(differential(a), mod_inv(a));
    b = integral(b);
    b[0] = 0;
    return b;
}

// Power-series exponential: returns exp(a) mod x^{a.size()} via the Newton step
// res <- res * (1 - ln(res) + a). Requires a non-empty with a[0] == 0, and an inverse table
// (prepare_factorials) long enough for the inner ln calls.
vector<int> exp(vector<int> a) {
    int N = round_up_power_two(a.size());
    int n = a.size();
    assert(a[0] == 0);
    a.resize(N);
    vector<int> res{ 1 };
    for (int i = 2; i <= N; i <<= 1) {
        auto tmp = res;
        tmp.resize(i);
        tmp = ln(tmp);
        // ln returns about 2i terms. Keep only the first i terms: the extra ones would wrap
        // around in the length-2i cyclic product below and corrupt the low coefficients.
        tmp.resize(i);
        for (int j = 0; j < i; j++) {
            tmp[j] = sub(a[j], tmp[j]);
        }
        tmp[0] = add(tmp[0], 1);
        res.resize(i);
        res = mod_multiply(res, tmp, i << 1);
        fill(res.begin() + i, res.end(), 0);
    }
    res.resize(n);
    return res;
}

// Multiplies many polynomials whose total degree is n in O(n log^2 n)
// (the two shortest are always merged first). Returns {1} for an empty list.
vector<int> mod_multiply_all(const vector<vector<int>>& polynomials) {
    if (polynomials.empty())
        return { 1 };
    struct compare_size {
        bool operator()(const vector<int>& x, const vector<int>& y) const {
            return x.size() > y.size();
        }
    };
    priority_queue<vector<int>, vector<vector<int>>, compare_size> pq(polynomials.begin(), polynomials.end());
    while (pq.size() > 1) {
        vector<int> a = pq.top(); pq.pop();
        vector<int> b = pq.top(); pq.pop();
        pq.push(mod_multiply(a, b));
    }
    return pq.top();
}

// Preprocessing for the power-series power: reads the decimal exponent s and returns
// (s mod MOD, s mod Phi_MOD, zero). Phi_MOD is Euler's totient of MOD. zero is set once a
// decimal prefix of s reaches n; for n far below Phi_MOD / 10, this means s >= n.
tuple<int, int, bool> power_reduction(string s, int n) {
    int p = 0, q = 0;
    bool zero = false;
    for (char ch : s) {
        const int digit = ch - '0';
        p = add(mul(p, 10), digit);
        q = (int)(1LL * q * 10 % Phi_MOD) + digit;
        if (q >= Phi_MOD) q -= Phi_MOD;
        if (q >= n) zero = true;
    }
    return { p, q, zero };
}

// Power-series power a^s mod x^{a.size()}, where s is a non-negative decimal string that
// may be huge. Total cost is O(n log n).
// Write a = c * x^mn * a' with a'[0] == 1. Then a^s = c^s * x^(mn*s) * exp(s * ln(a')).
// Here c^s uses s mod Phi_MOD and s * ln uses s mod MOD. a^0 is 1 (including 0^0).
// ln/exp need prepare_factorials to cover roughly 2 * a.size().
vector<int> power(vector<int> a, string s) {
    int n = a.size();
    if (n == 0) return {};
    // The reduction step can be skipped if the exponent never needs reducing.
    auto [p, q, zero] = power_reduction(s, n);
    if (p == 0 && q == 0 && !zero) {
        // The exponent is exactly 0. Handle it up front: for mn > 0 the shift logic below
        // would index past the end of copy_a.
        vector<int> res(n, 0);
        res[0] = 1;
        return res;
    }
    if (a[0] == 1) {
        auto res = ln(a);
        if ((int)res.size() > n) res.resize(n);
        for (auto& coef : res) {
            coef = mul(p, coef);
        }
        return exp(res);
    }
    int mn = -1;  // index of the lowest non-zero coefficient
    for (int i = 0; i < n; i++) {
        if (a[i]) {
            mn = i;
            break;
        }
    }
    // All coefficients are zero, or the shift mn * s reaches past x^{n-1}.
    if ((mn == -1) || (mn && (zero || (1LL * mn * p > n)))) {
        return vector<int>(n, 0);
    }
    int inverse_amin = exgcd(a[mn]);
    vector<int> copy_a;
    for (int i = mn; i < n; i++) {
        copy_a.emplace_back(mul(a[i], inverse_amin));
    }
    copy_a = ln(copy_a);
    if ((int)copy_a.size() > n) copy_a.resize(n);
    for (auto& coef : copy_a) {
        coef = mul(coef, p);
    }
    copy_a = exp(copy_a);
    vector<int> res(n, 0);
    // shift is the offset mn * s (p == s exactly here, since s < n whenever mn > 0);
    // power_k = a[mn]^q, with q the exponent reduced by Euler's theorem.
    int shift = mn * p, power_k = powmod(a[mn], q);
    for (int i = 0; i + shift < n; i++) {
        res[i + shift] = mul(copy_a[i], power_k);
    }
    return res;
}

// Correlation ("subtraction convolution"): returns c of length a.size() with
// c[k] = sum_i a[i + k] * b[i], i.e. only the non-negative-degree terms of a(x) * b(1/x).
vector<long long> sub_convolution(vector<int> a, vector<int> b) {
    int n = b.size();
    if (n == 0) return vector<long long>(a.size(), 0);
    reverse(b.begin(), b.end());
    auto res = multiply(a, b);
    return vector<long long>(res.begin() + n - 1, res.end());
}
}  // namespace FFT
// Solves "F - Knapsack for All Segments". Reads N (n), S (m) and A_1..A_N, then prints
// sum over R of a[m] after processing A_R, modulo MOD. Here a[j] counts pairs
// (start L <= R, subset of A_L..A_R) whose sum is j.
// Multiplying a by (1 + x^{A_i}) truncated to degree m is done in place in O(m). This replaces
// one full FFT multiplication per element (O(N * S) total instead of O(N * S log S)).
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m, res = 0;
    cin >> n >> m;
    vector<int> b(n);
    for (auto& value : b) {
        cin >> value;
    }
    vector<int> a(m + 1, 0);
    a[0] = 1;  // segment starting at L = 1, empty subset
    for (int value : b) {
        // a *= (1 + x^value) mod x^{m+1}. Iterating j downward means each a[j - value] read
        // is still the old value, so the element is used at most once.
        for (int j = m; j >= value; --j) {
            a[j] = add(a[j], a[j - value]);
        }
        res = add(res, a[m]);  // every segment ending at the current element
        a[0] = add(a[0], 1);   // open a new segment starting at the next element
    }
    cout << res;
    return 0;
}
Submission Info
Submission Time: 2021-08-15 14:53:19+0900
Task: F - Knapsack for All Segments
User: st1vdy
Language: C++ (GCC 9.2.1)
Score: 600
Code Size: 15037 Byte
Status: AC
Exec Time: 1515 ms
Memory: 4920 KiB
Compile Error (compiler warnings only; the submission compiled and was judged AC)
./Main.cpp: In function ‘std::vector<long long int> FFT::multiply(std::vector<int>, std::vector<int>)’:
./Main.cpp:193:27: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<int>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
193 | for (int i = 0; i < a.size(); i++) tmp[i].real = a[i];
| ~~^~~~~~~~~~
./Main.cpp:194:27: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<int>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
194 | for (int i = 0; i < b.size(); i++) tmp[i].imag = b[i];
| ~~^~~~~~~~~~
./Main.cpp:200:27: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<long long int>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
200 | for (int i = 0; i < res.size(); i++) {
| ~~^~~~~~~~~~~~
./Main.cpp: In function ‘std::tuple<int, int, bool> FFT::power_reduction(std::string, int)’:
./Main.cpp:338:27: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::__cxx11::basic_string<char>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
338 | for (int i = 0; i < s.length(); i++) {
| ~~^~~~~~~~~~~~
./Main.cpp: In function ‘int main()’:
./Main.cpp:414:25: warning: comparison of integer expressions of different signedness: ‘std::vector<int>::size_type’ {aka ‘long unsigned int’} and ‘int’ [-Wsign-compare]
414 | while (a.size() > m + 1) {
| ~~~~~~~~~^~~~~~~
Judge Result
Set Name | Score / Max Score | Status  | Test Cases
sample   | 0 / 0             | AC × 3  | sample01, sample02, sample03
All      | 600 / 600         | AC × 23 | 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33, 34, 35, 41, 42, 43, 44, 45, sample01, sample02, sample03
Case Name | Status | Exec Time | Memory
11        | AC     | 16 ms     | 4248 KiB
12        | AC     | 19 ms     | 4128 KiB
13        | AC     | 27 ms     | 4188 KiB
14        | AC     | 14 ms     | 4184 KiB
15        | AC     | 7 ms      | 4072 KiB
21        | AC     | 1184 ms   | 3736 KiB
22        | AC     | 988 ms    | 4908 KiB
23        | AC     | 1001 ms   | 4916 KiB
24        | AC     | 1292 ms   | 4920 KiB
25        | AC     | 214 ms    | 4768 KiB
31        | AC     | 39 ms     | 3740 KiB
32        | AC     | 51 ms     | 3676 KiB
33        | AC     | 46 ms     | 3740 KiB
34        | AC     | 45 ms     | 3716 KiB
35        | AC     | 46 ms     | 3576 KiB
41        | AC     | 1515 ms   | 4392 KiB
42        | AC     | 1340 ms   | 4312 KiB
43        | AC     | 1486 ms   | 4380 KiB
44        | AC     | 1491 ms   | 4384 KiB
45        | AC     | 1497 ms   | 4304 KiB
sample01  | AC     | 2 ms      | 3596 KiB
sample02  | AC     | 2 ms      | 3540 KiB
sample03  | AC     | 2 ms      | 3536 KiB