
G - Socks 3 Editorial by en_translator


Let \(S=A_1+A_2+\dots+A_N\). First, by the pigeonhole principle, Takahashi draws at most \(N+1\) socks: among any \(N+1\) socks, two must share one of the \(N\) colors.

For \(i=1,\dots,N+1\), let \(P_i\) be the probability that he draws a sock at least \(i\) times. The sought expected value is \(P_1+P_2+\dots+P_{N+1}\).

(Proof sketch) Let \(Q_i\) be the probability that he draws exactly \(i\) socks. Then the expected value is \(1\cdot Q_1 + 2\cdot Q_2 + \dots + (N+1)\cdot Q_{N+1}\). Substituting \(Q_i=P_i-P_{i+1}\) (with \(P_{N+2}=0\)), the sum telescopes: each \(P_i\) ends up with coefficient \(i-(i-1)=1\), which gives \(P_1+P_2+\dots+P_{N+1}\).

\(P_{i+1}\) equals the probability that the first \(i\) socks drawn all have distinct colors. Every \(i\)-element subset of the \(S\) socks is equally likely to be the first \(i\) socks drawn, so, denoting by \(f_i\) the number of ways to choose \(i\) socks of pairwise distinct colors among the \(S\) socks, we have \(P_{i+1}=\frac{f_i}{\binom{S}{i}}\). Computing \(\binom{S}{0},\binom{S}{1},\dots,\binom{S}{N}\) is straightforward, so all that is left is to find \(f_i\).

Define the polynomial \(F\) by \(\displaystyle F=\sum_{i=0}^{N} f_i x^i\) (note that \(f_i=0\) for \(i>N\), since there are only \(N\) colors). Since

\[\displaystyle f_i=\sum_{1\leq k_1 < k_2 < \dots < k_i\leq N} \prod_{l=1}^{i} A_{k_l},\]

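\(F\) is the product \(F_1F_2\cdots F_N\), where \(F_i=1+A_ix\): when the product is expanded, choosing the term \(A_ix\) from \(F_i\) corresponds to picking one of the \(A_i\) socks of color \(i\). The problem thus boils down to the following: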

  • Given \(N\) linear polynomials, find their product.

This is a well-known problem that can be solved by a simple divide-and-conquer algorithm. Define \(f(l,r)\) as the product \(F_lF_{l+1}\cdots F_{r-1}\), and evaluate it recursively using \(f(l,r) = f(l,m)\times f(m,r)\), where \(m=\lfloor\frac{l+r}{2}\rfloor\). If each multiplication is done with NTT (Number-Theoretic Transform), the polynomials at each level of the recursion have total degree \(N\), so each level costs \(O(N\log N)\); with \(O(\log N)\) levels, the overall complexity is \(O(N\log^2 N)\).

Sample code (C++):

#include <bits/stdc++.h>
#include <atcoder/modint>
#include <atcoder/convolution>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

// \prod_{i=l}^{r-1} (1+a[i]x)
vector<mint> prod(const vector<int> &a, int l, int r) {
    if (r - l == 1) return {1, a[l]};
    int m = (l + r) / 2;
    return convolution(prod(a, l, m), prod(a, m, r));
}

int main() {
    int n;
    cin >> n;
    vector<int> a(n);
    int s = 0;
    for (int &i: a) {
        cin >> i;
        s += i;
    }
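    vector<mint> f = prod(a, 0, n);  // f[i]: ways to choose i socks of pairwise distinct colors
    mint ans = 0;
    mint sCi = 1;  // binom(s, i), updated incrementally
    for (int i = 0; i <= n; i++) {
        ans += f[i] / sCi;  // P_{i+1} = f_i / binom(S, i)
        sCi *= s - i;       // binom(S, i+1) = binom(S, i) * (S - i) / (i + 1)
        sCi /= i + 1;
    }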
    cout << ans.val() << endl;
}
