F - Loud Cicada Editorial by en_translator
Implementation approach
For brevity, let “condition \(i\)” denote the condition of being a multiple of \(A_i\).
Identify the \(N\) conditions with their indices, and consider a subset \(S \subseteq \{1, \dots, N\}\) of them.
For each subset \(S\), define the “exact condition \(S\)” as follows: all conditions in \(S\) are met, while the others are violated.
Every integer satisfies exactly one exact condition, namely the one whose set \(S\) is the set of conditions it meets. Therefore, if for every subset \(S\) we can find the number of integers between \(1\) and \(Y\) that satisfy exact condition \(S\), the answer is the sum of these counts over all subsets \(S\) with \(|S| = M\).
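To make the definitions concrete, here is a minimal brute-force sketch. It runs in \(O(YN)\) time, so it is only practical for small \(Y\), for instance as a reference when testing a faster solution. Since each integer \(x\) meets exactly one exact condition, tallying the bitmask of conditions that \(x\) satisfies yields all exact counts at once. The names `popcount_of`, `exact_counts_brute` and `answer_brute` are hypothetical, and subsets are encoded as bitmasks in which bit \(i\) stands for the condition on `a[i]` (0-indexed).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: number of set bits in x.
int popcount_of(std::size_t x) {
    int c = 0;
    for (; x != 0; x &= x - 1) ++c;  // each step clears the lowest set bit
    return c;
}

// Brute force, O(Y * N): exact[S] = number of x in [1, y] meeting exactly the conditions in S.
std::vector<long long> exact_counts_brute(const std::vector<long long>& a, long long y) {
    const int n = static_cast<int>(a.size());
    std::vector<long long> exact(std::size_t{1} << n, 0);
    for (long long x = 1; x <= y; ++x) {
        std::size_t mask = 0;  // set of conditions that x satisfies
        for (int i = 0; i < n; ++i) {
            if (x % a[i] == 0) mask |= std::size_t{1} << i;
        }
        ++exact[mask];  // x satisfies exact condition `mask` and no other exact condition
    }
    return exact;
}

// Answer = sum of exact[S] over all subsets S with |S| = m.
long long answer_brute(const std::vector<long long>& a, int m, long long y) {
    const std::vector<long long> exact = exact_counts_brute(a, y);
    long long ans = 0;
    for (std::size_t s = 0; s < exact.size(); ++s) {
        if (popcount_of(s) == m) ans += exact[s];
    }
    return ans;
}
```

For example, with \(A = (2, 3)\) and \(Y = 10\), the exact counts are \(3\) for \(\emptyset\) (1, 5, 7), \(4\) for "multiple of 2 only", \(2\) for "multiple of 3 only" and \(1\) for both (6), so the answer for \(M = 1\) is \(6\).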
Next, for each subset \(S\), define the “partial condition \(S\)” as follows: all conditions in \(S\) are met, while the others may be either satisfied or violated.
Partial conditions are easier to handle than exact ones: partial condition \(\{i_1, \dots, i_k\}\) can simply be rephrased as “being a multiple of the least common multiple (LCM) of \(A_{i_1}, \dots, A_{i_k}\)” (for \(S = \emptyset\), take the LCM to be \(1\)). If this LCM is \(x\), the number of integers between \(1\) and \(Y\) satisfying the partial condition is \(\lfloor Y/x \rfloor\). The LCM itself may overflow, but its exact value only matters while it is at most \(Y\). So we can, for example, represent every value greater than \(Y\) as \(Y+1\) (which correctly gives \(\lfloor Y/(Y+1) \rfloor = 0\)), and check size conditions with divisions instead of multiplications wherever possible.
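The capped LCM computation might look like the following sketch. It assumes every \(A_i \ge 1\) and that \(Y + 1\) fits in a `long long`; `capped_lcm` and `partial_count` are hypothetical helper names, and `std::gcd` requires C++17. The overflow check relies on the fact that, for positive integers, \(q \cdot a > Y\) holds exactly when \(q > \lfloor Y/a \rfloor\).

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Returns lcm(x, a) if it is at most y, and y + 1 otherwise.
// Assumes 1 <= x <= y + 1, a >= 1, and that y + 1 does not overflow.
long long capped_lcm(long long x, long long a, long long y) {
    if (x > y) return y + 1;      // already saturated: the LCM can only grow
    const long long g = std::gcd(x, a);
    const long long q = x / g;    // lcm(x, a) = q * a
    if (q > y / a) return y + 1;  // q * a > y, detected with a division instead of a multiplication
    return q * a;
}

// Number of integers in [1, y] satisfying the partial condition encoded by bitmask s
// (bit i set means "multiple of a[i]").
long long partial_count(const std::vector<long long>& a, unsigned s, long long y) {
    long long l = 1;  // LCM of the empty set
    for (std::size_t i = 0; i < a.size(); ++i) {
        if ((s >> i) & 1u) l = capped_lcm(l, a[i], y);
    }
    return y / l;  // becomes 0 once l has been capped to y + 1
}
```

Because the capped value \(Y+1\) is absorbing (any further LCM stays above \(Y\)), a loop may also stop early as soon as the cap is reached.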
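Once the count for each partial condition is known, applying the inverse of the (superset) zeta transform, called the Möbius transform, gives the count for each exact condition, and hence the answer. Concretely, the algorithm processes the conditions one at a time. After step \(k\), each stored value counts the integers that satisfy the exact condition with respect to conditions \(1, \dots, k\) and the partial condition with respect to conditions \(k+1, \dots, N\). Step \(k+1\) turns condition \(k+1\) from partial into exact: for each set \(S\) not containing \(k+1\), the count in which condition \(k+1\) is violated equals the count in which it may hold either way (the value for \(S\)) minus the count in which it holds (the value for \(S \cup \{k+1\}\)).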
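Below is a sketch of this transform, together with a naive \(O(4^N)\) version that evaluates the explicit inclusion-exclusion formula \(\text{exact}(S) = \sum_{T \supseteq S} (-1)^{|T|-|S|}\,\text{partial}(T)\). The naive version is only meant as a cross-check for small \(N\). The names `mobius_superset`, `popcount_bits` and `mobius_naive` are hypothetical, and bit \(i\) of a mask stands for the condition on `a[i]`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// In-place Mobius transform over supersets, O(2^n * n).
// Input:  f[S] = partial count for S (= sum of exact counts over all supersets T of S).
// Output: f[S] = exact count for S.
void mobius_superset(std::vector<long long>& f, int n) {
    const std::size_t full = std::size_t{1} << n;
    for (int bit = 0; bit < n; ++bit) {
        // Before this pass, bits 0..bit-1 are exact and bits bit..n-1 are partial.
        for (std::size_t s = 0; s < full; ++s) {
            if (!((s >> bit) & 1u)) {
                // "bit violated" = "bit either way" - "bit satisfied"
                f[s] -= f[s | (std::size_t{1} << bit)];
            }
        }
    }
}

int popcount_bits(std::size_t x) {
    int c = 0;
    for (; x != 0; x &= x - 1) ++c;  // clear the lowest set bit
    return c;
}

// Naive inclusion-exclusion, O(4^n): exact[S] = sum over T superset of S of (-1)^(|T|-|S|) partial[T].
std::vector<long long> mobius_naive(const std::vector<long long>& partial, int n) {
    const std::size_t full = std::size_t{1} << n;
    std::vector<long long> exact(full, 0);
    for (std::size_t s = 0; s < full; ++s) {
        for (std::size_t t = 0; t < full; ++t) {
            if ((t & s) != s) continue;  // t must contain s
            const int diff = popcount_bits(t) - popcount_bits(s);
            exact[s] += (diff % 2 == 0 ? 1 : -1) * partial[t];
        }
    }
    return exact;
}
```

For \(A = (2, 3)\) and \(Y = 10\), the partial counts \((10, 5, 3, 1)\) (indexed by masks \(0, 1, 2, 3\)) are transformed into the exact counts \((3, 4, 2, 1)\) by both functions.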
Sample code (C++)
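```cpp
#include <iostream>
#include <vector>
using std::cin;
using std::cout;
typedef long long int ll;
ll n, m, y;
ll a[21];
ll zeta[1LL << 21];  // zeta[b]: count for partial condition b, later for exact condition b
ll mygcd (ll l, ll r) {
    if (r == 0) return l;
    return mygcd(r, l % r);
}
int main (void) {
    cin >> n >> m >> y;
    for (ll i = 0; i < n; i++) {
        cin >> a[i];
    }
    // Partial condition b: count = floor(y / LCM of the chosen a[j]).
    // Any LCM greater than y is represented as y + 1.
    for (ll b = 0; !(b >> n); b++) {
        ll prod = 1;
        for (ll j = 0; j < n; j++) {
            if (b & (1LL << j)) {
                ll g = mygcd(prod, a[j]);
                // lcm(prod, a[j]) = (prod / g) * a[j]; compare with a division to avoid overflow
                if ((prod / g) > y / a[j]) {
                    prod = y+1;  // exceeds y; adding more conditions cannot make it smaller
                    break;
                } else {
                    prod = (prod / g) * a[j];
                }
            }
        }
        zeta[b] = y / prod;
    }
    // Mobius transform: zeta[01] <- mobius(zeta[*1]), one bit at a time.
    for (ll bi = 0; bi < n; bi++) {
        for (ll b = 0; !(b >> n); b++) {
            // [0] <- [*] - [1]: "violated" = "either way" - "satisfied"
            if (!(b & (1LL << bi))) {
                zeta[b] = zeta[b] - zeta[b | (1LL << bi)];
            }
        }
    }
    // ans[k]: sum of exact counts over all subsets of size k
    std::vector<ll> ans(n + 1, 0);
    for (ll b = 0; !(b >> n); b++) {
        ll cnt = 0;
        for (ll j = 0; j < n; j++) {
            if (b & (1LL << j)) {
                cnt++;
            }
        }
        ans[cnt] += zeta[b];
    }
    cout << ans[m] << "\n";
    return 0;
}
```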