C - Product Modulo Editorial by Errichto
Hints
Hint 1
There is something in math that makes multiplication modulo $P$ less chaotic. (Yes, this is quite a mysterious hint.)
Hint 2
Use a primitive root (generator) modulo $P$.
Hint 3
Change every number $x$ to its position in the sequence $g^0, g^1, g^2, \ldots$
Hint 4
Then FFT.
Idea
There is a way to reorder numbers \(1, 2, \ldots, P-1\) in such a way that we can quickly multiply two numbers and get their product without operators * or %.
Solution
\(g = 2\) is a primitive root modulo \(P\). The following equation changes the problem into convolution solvable with FFT:
\[g^i \cdot g^j = g^{i+j}\]
Compute the sequence \(g^0, g^1, \ldots, g^{P-2}\) and save the mapping from each value to its position in the sequence, which is just the discrete logarithm \(\log_g(x)\).
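A small worked example may make the mapping concrete. It uses \(P = 7\) and \(g = 3\) purely for illustration (these values are not from the problem, where \(P = 200\,003\) and \(g = 2\)):
\[g^0, g^1, \ldots, g^5 \equiv 1, 3, 2, 6, 4, 5 \pmod 7\]
so the positions are \(\log_3 1 = 0\), \(\log_3 3 = 1\), \(\log_3 2 = 2\), \(\log_3 6 = 3\), \(\log_3 4 = 4\), \(\log_3 5 = 5\). To multiply \(6\) and \(4\), add their positions: \(3 + 4 = 7 \equiv 1 \pmod 6\), and \(g^1 = 3\), which matches \(6 \cdot 4 = 24 \equiv 3 \pmod 7\).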
While reading the input, for every non-zero \(x\) do cnt[position[x]]++ (zeros can be skipped because every product involving a zero is \(0\)). Compute the convolution \(cnt \times cnt\) with FFT.
For every \(k\), the convolution gives the number of pairs of elements \((x, y)\) in the input such that \(x = g^r\), \(y = g^t\) and \(k = r + t\) (so \(x \cdot y = g^k\)). Since \(g^{P-1} \equiv 1 \pmod P\), the values \(g^k\) and \(g^{k+(P-1)}\) are the same number modulo \(P\), so the two counts have to be merged to get the number of pairs with exactly this product. The answer (for ordered pairs) is:
\[\sum_{k=0}^{P-2} (g^k \bmod P) \cdot (fftResult[k] + fftResult[k+(P-1)]) \]
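To count only unordered pairs \((i < j)\), subtract the products \((a_i \cdot a_i \bmod P)\) of every element with itself, and at the end divide the answer by \(2\). This is an ordinary integer division rather than a modular inverse, because the final sum is not reduced modulo \(P\) (the code prints it as a plain long long). Don’t forget about long longs.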
code, 40 lines + FFT
#include <bits/stdc++.h> // Product Modulo, by Errichto
using namespace std;
#define REP(i,n) for(int i = 0; i < int(n); ++i)
typedef double ld; // 'long double' is 2.2 times slower
// minimal complex number type used by the FFT
struct C { ld real, imag;
    C operator * (const C & he) const {
        return C{real * he.real - imag * he.imag,
                real * he.imag + imag * he.real};
    }
    void operator += (const C & he) { real += he.real; imag += he.imag; }
};
void dft(vector<C> & a, bool rev) {
    const int n = a.size();
    // bit-reversal permutation
    for(int i = 1, k = 0; i < n; ++i) {
        for(int bit = n / 2; (k ^= bit) < bit; bit /= 2);
        if(i < k) swap(a[i], a[k]);
    }
    for(int len = 1, who = 0; len < n; len *= 2, ++who) {
        static vector<C> t[30]; // cached roots of unity, one table per level
        vector<C> & om = t[who];
        if(om.empty()) {
            om.resize(len);
            const ld ang = 2 * acosl(0) / len; // pi / len
            REP(i, len) om[i] = i%2 || !who ?
                C{cos(i*ang), sin(i*ang)} : t[who-1][i/2];
        }
        for(int i = 0; i < n; i += 2 * len)
            REP(k, len) {
                const C x = a[i+k], y = a[i+k+len]
                    * C{om[k].real, om[k].imag * (rev ? -1 : 1)};
                a[i+k] += y;
                a[i+k+len] = C{x.real - y.real, x.imag - y.imag};
            }
    }
    if(rev) REP(i, n) a[i].real /= n;
}
template<typename T>
vector<T> multiply(const vector<T> & a, const vector<T> & b) {
    if(a.empty() || b.empty()) return {};
    int n = a.size() + b.size();
    vector<T> ans(n - 1);
    /* if(min(a.size(),b.size()) < 190) { // BRUTE FORCE
        REP(i, a.size()) REP(j, b.size()) ans[i+j] += a[i]*b[j];
        return ans; } */
    while(n&(n-1)) ++n; // round up to a power of two
    // a goes to the real part and b to the imaginary part of one array;
    // speed() recovers the DFT of a (k = 0) or of b (k = 1) from that single DFT
    auto speed = [&](const vector<C> & w, int i, int k) {
        int j = i ? n - i : 0, r = k ? -1 : 1;
        return C{w[i].real + w[j].real * r, w[i].imag
                - w[j].imag * r} * (k ? C{0, -0.5} : C{0.5, 0});
    };
    vector<C> in(n), done(n);
    REP(i, a.size()) in[i].real = a[i];
    REP(i, b.size()) in[i].imag = b[i];
    dft(in, false);
    REP(i, n) done[i] = speed(in, i, 0) * speed(in, i, 1);
    dft(done, true);
    REP(i, ans.size()) ans[i] = is_integral<T>::value ?
        llround(done[i].real) : done[i].real;
    //REP(i,ans.size())err=max(err,abs(done[i].real-ans[i]));
    return ans;
}
const int P = 200003;
int main() {
    int g = 2;
    vector<int> order{1}; // order[i] = g^i mod P
    for(int x = g; x != 1; x = (long long) x * g % P) {
        order.push_back(x);
    }
    assert((int) order.size() == P - 1); // holds because g is a primitive root
    vector<int> where(P); // where[x] = position of x in order
    for(int i = 0; i < (int) order.size(); ++i) {
        where[order[i]] = i;
    }
    vector<long long> cnt(P);
    int n;
    scanf("%d", &n);
    vector<int> a(n);
    for(int& x : a) {
        scanf("%d", &x);
        if(x != 0) {
            cnt[where[x]]++;
        }
    }
    vector<long long> multiplier(P); // multiplier[v] = ordered pairs with product v
    vector<long long> answer = multiply(cnt, cnt);
    for(int i = 0; i < (int) answer.size(); ++i) {
        if(answer[i]) {
            int product = order[i%order.size()];
            multiplier[product] += answer[i];
        }
    }
    for(int x : a) {
        --multiplier[(long long) x * x % P]; // drop the pairs (i, i)
    }
    long long total = 0;
    for(int i = 1; i < P; ++i) {
        assert(multiplier[i] % 2 == 0);
        total += multiplier[i] / 2 * i; // keep only i < j
    }
    printf("%lld\n", total);
}
Faster modulo multiplication?
With fast memory access, precomputations with a primitive root might speed up the calculation of \(x \cdot y \bmod P\).
In C++, if you use a constant like const int P = 200003;, the modulo operations are quite fast anyway. But if you read \(P\) from the input, replacing x * y % P with g_powers[pos[x]+pos[y]] is actually about 2.5 times faster on the author’s machine (5.3s and 2s for 1e9 multiplications, \(P = 200\,003\)). I don’t know how good this is compared to other known algorithms for speeding up the operation %P.
For values of \(P\) around \(10^9\), this is viable only locally to compute something once or if it’s an optimization contest. You need 8GB and a few seconds for preprocessing.
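A minimal sketch of such a lookup table is shown below. The names g_powers and pos follow the text above; the struct, its constructor and the explicit zero check are assumptions made for illustration, and no timing claim is attached to this version. The power table is stored up to index \(2P-4\) so that pos[x]+pos[y] never needs to be reduced.

```cpp
#include <cassert>
#include <vector>

// Illustrative only: g must be a primitive root modulo the prime p.
struct TableMul {
    std::vector<int> g_powers;  // g_powers[i] = g^i mod p, for 0 <= i <= 2p-4
    std::vector<int> pos;       // pos[x] = discrete log of x, for 1 <= x < p
    TableMul(int p, int g) : g_powers(2 * (p - 1) - 1), pos(p, 0) {
        long long x = 1;
        for (int i = 0; i < (int)g_powers.size(); ++i) {
            g_powers[i] = (int)x;
            if (i < p - 1) pos[x] = i;  // first period only: one log per residue
            x = x * g % p;
        }
    }
    // x * y mod p with one addition and three table lookups.
    // Zero has no logarithm, so it is handled separately.
    int mul(int x, int y) const {
        if (x == 0 || y == 0) return 0;
        return g_powers[pos[x] + pos[y]];
    }
};
```

With TableMul t(200003, 2), the call t.mul(x, y) should agree with (long long) x * y % 200003 for all residues x and y.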