Submission #15794827
Source Code Expand
#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
#include <complex>
/// Integer arithmetic modulo the compile-time constant MOD.
///
/// The stored residue `val` is always kept in the range [0, MOD).
/// Division and inv() use Fermat's little theorem (x^(MOD-2)), so they are
/// only meaningful when MOD is prime and the divisor is non-zero.
template <int MOD>
struct ModInt {
    using lint = long long;

    int val;  // canonical residue in [0, MOD)

    /// Builds the residue of `v`; negative inputs are mapped into [0, MOD).
    ModInt(lint v = 0) : val(v % MOD) {
        if (val < 0) val += MOD;
    }

    // ---- unary operators ----
    ModInt operator+() const { return ModInt(val); }
    // MOD - 0 == MOD is reduced back to 0 by the constructor.
    ModInt operator-() const { return ModInt(MOD - val); }
    /// Multiplicative inverse via x^(MOD-2).
    ModInt inv() const { return this->pow(MOD - 2); }

    // ---- binary arithmetic (implemented through the compound forms) ----
    ModInt operator+(const ModInt& x) const { return ModInt(*this) += x; }
    ModInt operator-(const ModInt& x) const { return ModInt(*this) -= x; }
    ModInt operator*(const ModInt& x) const { return ModInt(*this) *= x; }
    ModInt operator/(const ModInt& x) const { return ModInt(*this) /= x; }

    /// Binary exponentiation. Returns 1 for every n <= 0.
    ModInt pow(lint n) const {
        ModInt result(1);
        ModInt base = *this;
        for (; n > 0; n >>= 1) {
            if (n & 1) result *= base;
            base *= base;
        }
        return result;
    }

    // ---- compound assignment ----
    ModInt& operator+=(const ModInt& x) {
        // val + x.val may exceed INT_MAX when MOD > 2^30, so subtract the
        // complement instead: val - (MOD - x.val) lies in [-MOD, MOD).
        val -= MOD - x.val;
        if (val < 0) val += MOD;
        return *this;
    }
    ModInt& operator-=(const ModInt& x) {
        val -= x.val;
        if (val < 0) val += MOD;
        return *this;
    }
    ModInt& operator*=(const ModInt& x) {
        // Widen to 64 bits before multiplying; the product is below 2^62.
        val = static_cast<int>(static_cast<lint>(val) * x.val % MOD);
        return *this;
    }
    ModInt& operator/=(const ModInt& x) { return *this *= x.inv(); }

    // ---- comparisons (on the canonical residue) ----
    bool operator==(const ModInt& b) const { return val == b.val; }
    bool operator!=(const ModInt& b) const { return val != b.val; }
    bool operator<(const ModInt& b) const { return val < b.val; }
    bool operator<=(const ModInt& b) const { return val <= b.val; }
    bool operator>(const ModInt& b) const { return val > b.val; }
    bool operator>=(const ModInt& b) const { return val >= b.val; }

    // ---- stream I/O ----
    // Not noexcept: stream operations throw when the stream has exceptions enabled.
    friend std::istream& operator>>(std::istream& is, ModInt& x) {
        lint raw = 0;
        is >> raw;
        x = raw;  // reduced by the converting constructor
        return is;
    }
    friend std::ostream& operator<<(std::ostream& os, const ModInt& x) { return os << x.val; }
};
/// Double-precision complex FFT used to multiply integer polynomials.
///
/// @tparam K number of root-of-unity levels cached at construction:
///           zetas[i] = exp(2*pi*I / 2^i) for i in [0, K). Levels beyond the
///           cache are computed on demand, so larger transforms stay valid.
template <int K>
struct FastFourierTransform {
    using cplx = std::complex<double>;
    using cplxs = std::vector<cplx>;
    static constexpr double PI = 3.14159265358979323846L;

    cplxs zetas;  // cached primitive 2^i-th roots of unity

    /// Primitive 2^level-th root of unity, exp(2*pi*I / 2^level).
    static cplx root(int level) { return std::polar(1., PI * 2 / std::ldexp(1., level)); }

    explicit FastFourierTransform() : zetas(K) {
        for (int i = 0; i < K; ++i) {
            zetas[i] = root(i);
        }
    }

    /// Permutes `f` into bit-reversed index order (intended for power-of-two lengths).
    void bitrev(cplxs& f) const {
        const int n = f.size();
        for (int i = 0; i < n; ++i) {
            // Reverse the low log2(n) bits of i.
            int rest = i, rev = 0;
            for (int k = 0; (1 << k) < n; ++k) {
                rev = (rev << 1) + (rest & 1);
                rest >>= 1;
            }
            // Swap each pair exactly once.
            if (i < rev) {
                std::swap(f[i], f[rev]);
            }
        }
    }

    /// Butterfly passes over data already in bit-reversed order.
    /// With `isinv` the conjugate roots are used; no 1/n scaling is applied here.
    void udft(cplxs& f, bool isinv) const {
        const int n = f.size();
        if (n <= 1) return;
        // `half` is half the block size of the current stage; iterating on it
        // (instead of the full block size) avoids shifting past the int range.
        int level = 1;
        for (int half = 1; half <= n / 2; half <<= 1, ++level) {
            const int block = half * 2;
            // Use the cached root when available, otherwise compute it, so the
            // table size K never causes an out-of-bounds read.
            cplx zeta = (level < K) ? zetas[level] : root(level);
            if (isinv) zeta = std::conj(zeta);
            for (int r = 0; r < n / block; ++r) {
                cplx zetapow = 1;
                for (int j = 0; j < half; ++j) {
                    const int lo = r * block + j;
                    const int hi = lo + half;
                    const cplx t = zetapow * f[hi];
                    f[hi] = f[lo] - t;
                    f[lo] = f[lo] + t;
                    zetapow *= zeta;
                }
            }
        }
    }

    /// In-place transform: bit-reversal followed by the butterfly passes.
    void dft(cplxs& f, bool isinv) const {
        bitrev(f);
        udft(f, isinv);
    }

    // main routine
    using lint = long long;
    using lints = std::vector<lint>;

    /// Convolution of the coefficient sequences `ff` and `gf`.
    ///
    /// @return ff.size() + gf.size() - 1 coefficients rounded to the nearest
    ///         integer, or an empty vector when that length is not positive.
    lints fft(const lints& ff, const lints& gf) const {
        const int fdeg = ff.size();
        const int gdeg = gf.size();
        const int out_len = fdeg + gdeg - 1;
        // Two empty inputs would otherwise request a resize to -1.
        if (out_len <= 0) return {};

        // Smallest power of two that is >= fdeg + gdeg.
        int n = 1;
        while (n < fdeg + gdeg) n <<= 1;

        cplxs f = li2cp(ff);
        cplxs g = li2cp(gf);
        f.resize(n, 0);
        g.resize(n, 0);
        dft(f, false);
        dft(g, false);

        // Pointwise product in the frequency domain.
        cplxs h(n);
        for (int i = 0; i < n; ++i) h[i] = f[i] * g[i];

        dft(h, true);
        h.resize(out_len);
        // The inverse transform is unnormalised; divide by the length.
        for (auto& x : h) x /= static_cast<double>(n);
        return cp2li(h);
    }

    // lint <-> complex converters

    /// Lifts each integer to a complex number with zero imaginary part.
    cplxs li2cp(const lints& f) const {
        cplxs ret;
        ret.reserve(f.size());
        for (const lint x : f) ret.emplace_back(static_cast<double>(x));
        return ret;
    }

    /// Rounds the real part of each entry to the nearest integer.
    lints cp2li(const cplxs& f) const {
        lints ret;
        ret.reserve(f.size());
        for (const cplx& x : f) ret.push_back(std::llround(x.real()));
        return ret;
    }
};
// Modulus shared by the modular-integer type and the tables built in solve().
constexpr int MOD = 200003;
// Modular integer specialised for MOD.
using mint = ModInt<MOD>;
// Shared FFT instance with 20 precomputed root-of-unity levels.
const FastFourierTransform<20> FFT;
// 64-bit integer used for counts and the final sum.
using lint = long long;
void solve() {
// g^log[x] = x
std::vector<int> log(MOD);
mint g = 2;
for (int i = 0; i < MOD - 1; ++i) {
log[g.pow(i).val] = i;
}
std::vector<lint> cnt(MOD - 1, 0);
int n;
std::cin >> n;
while (n--) {
int a;
std::cin >> a;
if (a != 0) ++cnt[log[a]];
}
// FFTで畳み込む
auto res = FFT.fft(cnt, cnt);
lint ans = 0;
for (int i = 0; i < (int)res.size(); ++i) {
lint num = llround(res[i]);
ans += g.pow(i).val * num;
}
// a_i * a_iを省く
for (int i = 0; i < (int)cnt.size(); ++i) {
lint num = llround(cnt[i]);
ans -= g.pow(i * 2).val * num;
}
// i > jを省く
ans /= 2;
std::cout << ans << "\n";
}
// Entry point: configure fast stream I/O, then run the solver once.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    solve();
}
Submission Info
Submission Time |
|
Task |
C - Product Modulo |
User |
Tiramister |
Language |
C++ (GCC 9.2.1) |
Score |
800 |
Code Size |
5768 Byte |
Status |
AC |
Exec Time |
233 ms |
Memory |
37632 KiB |
Judge Result
Set Name |
Sample |
All |
Score / Max Score |
0 / 0 |
800 / 800 |
Status |
|
|
Set Name |
Test Cases |
Sample |
s1.txt, s2.txt |
All |
001.txt, 002.txt, 003.txt, 004.txt, 005.txt, 006.txt, 007.txt, 008.txt, 009.txt, 010.txt, 011.txt, 012.txt, 013.txt, s1.txt, s2.txt |
Case Name |
Status |
Exec Time |
Memory |
001.txt |
AC |
222 ms |
37632 KiB |
002.txt |
AC |
225 ms |
37500 KiB |
003.txt |
AC |
218 ms |
37632 KiB |
004.txt |
AC |
225 ms |
37540 KiB |
005.txt |
AC |
232 ms |
37492 KiB |
006.txt |
AC |
231 ms |
37540 KiB |
007.txt |
AC |
230 ms |
37496 KiB |
008.txt |
AC |
228 ms |
37556 KiB |
009.txt |
AC |
231 ms |
37536 KiB |
010.txt |
AC |
233 ms |
37552 KiB |
011.txt |
AC |
228 ms |
37544 KiB |
012.txt |
AC |
229 ms |
37540 KiB |
013.txt |
AC |
214 ms |
37520 KiB |
s1.txt |
AC |
216 ms |
37416 KiB |
s2.txt |
AC |
223 ms |
37524 KiB |