G - Random Subtraction Editorial by notonlysuccess

Formula Derivation Solution


In what follows, \(E[X]\) denotes the expected value of a random variable \(X\).

Let \(S_k\) denote the sum of all elements when \(k\) numbers remain.

Let \(P_k = E\left[S_k^2\right]\) denote the expected value of the square of the sum of all elements when k numbers remain. From this, we have:

  • \(P_n = \left(\sum a_i\right)^2\), which can be calculated directly.

  • \(P_1\) is the final answer we need to find.

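When transitioning from \(P_k \to P_{k-1}\), we choose two distinct numbers \(a_j, a_i\) and replace them with \(a_j - a_i\). The sum loses \(a_j + a_i\) and gains \(a_j - a_i\), so the change in the total sum is \(-2a_i\). The subtracted element \(a_i\) is uniformly random among the \(k\) remaining numbers, so, given the current state, its expected value is \(\frac{S_k}{k}\) and the expected value of its square is \(\frac{\sum a_i^2}{k}\).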

\[\begin{aligned} P_{k-1} &= E\left[(S_k - 2a_i) ^2\right] \\ &=E\left[S_k^2 - 4a_iS_k + 4a_i^2\right] \\ &= P_k - E\left[4\frac {S_k}kS_k\right] + E\left[4\frac {\sum a_i^2}{k}\right] \\ &= P_k(1 - \frac 4k)+E\left[\sum a_i^2\right]\frac 4k \end{aligned}\]

Let \(Q_k = E\left[\sum a_i^2\right]\) be the expected value of the sum of the squares of all elements when \(k\) numbers remain. We obtain:

\[P_{k-1} = P_k(1 - \frac 4k)+Q_k\frac 4k\]
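As a quick sanity check of this recurrence, consider \(n = 2\) with elements \(a_1, a_2\) (a worked example that is not part of the original derivation). Here \(P_2 = (a_1 + a_2)^2\) and \(Q_2 = a_1^2 + a_2^2\), so:

\[\begin{aligned} P_1 &= P_2\left(1 - \frac 42\right) + Q_2\frac 42 \\ &= -(a_1 + a_2)^2 + 2\left(a_1^2 + a_2^2\right) \\ &= (a_1 - a_2)^2 \end{aligned}\]

This agrees with direct reasoning: the last remaining number is either \(a_1 - a_2\) or \(a_2 - a_1\), and both have the same square.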

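Now, let's derive the transition from \(Q_k \to Q_{k-1}\). When we choose two distinct numbers \(a_i, a_j\), we remove \(a_i^2, a_j^2\) from the sum of squares and add \((a_i-a_j)^2\). The net change is \(-2a_ia_j\). Averaging \(a_ia_j\) over all \(k(k-1)\) ordered pairs with \(i \ne j\) gives \(\frac{(\sum a_i)^2 - \sum a_i^2}{k(k-1)}\), hence: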

\[\begin{aligned} Q_{k-1} &= Q_k - E[2a_ia_j] \\ &= Q_k - 2E\left[\frac {(\sum a_i)^2 - \sum a_i^2}{k(k-1)}\right] \\ &= Q_k - 2\frac {E[(\sum a_i)^2] - E[\sum a_i^2]}{k(k-1)} \\ &= Q_k - 2\frac {P_k - Q_k}{k(k-1)} \\ &= Q_k(1 + \frac 2 {k(k-1)}) - P_k\frac 2{k(k-1)} \end{aligned}\]
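A small numeric illustration of this step (an added example, assuming every ordered pair is equally likely, as the derivation implies): take the three numbers \(1, 2, 3\), so \(P_3 = 36\) and \(Q_3 = 14\).

\[\begin{aligned} Q_2 &= Q_3\left(1 + \frac 2{3 \cdot 2}\right) - P_3\frac 2{3 \cdot 2} \\ &= 14 \cdot \frac 43 - 36 \cdot \frac 13 \\ &= \frac{20}3 \end{aligned}\]

Enumerating the six ordered choices directly gives the two-element states \(\{3,-1\}, \{3,1\}, \{2,-2\}, \{2,2\}, \{1,-1\}, \{1,1\}\), whose sums of squares are \(10, 10, 8, 8, 2, 2\). Their average is \(\frac{40}6 = \frac{20}3\), matching the recurrence.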

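Using these two recurrences, we iterate from \(P_n, Q_n\) down to \(P_1, Q_1\), where \(Q_n = \sum a_i^2\) is also computed directly. At each step, both \(P_{k-1}\) and \(Q_{k-1}\) must be computed from the old values \(P_k, Q_k\).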
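The recurrences can be cross-checked against exhaustive enumeration on tiny inputs. The sketch below is an illustration only: the function names `bruteForceExpectedSquare` and `recurrenceExpectedSquare` are hypothetical, it works in floating point rather than modulo a prime, and it assumes that every ordered pair of distinct positions is chosen with equal probability, which is what the derivation implies.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Averages the squared final value over every ordered choice (j, i), j != i,
// where a_j and a_i are replaced by a_j - a_i. Exponential time: tiny inputs only.
double bruteForceExpectedSquare(const std::vector<double>& a) {
    assert(!a.empty());
    const std::size_t k = a.size();
    if (k == 1) return a[0] * a[0];
    double total = 0.0;
    for (std::size_t j = 0; j < k; ++j) {
        for (std::size_t i = 0; i < k; ++i) {
            if (i == j) continue;
            std::vector<double> next;
            next.reserve(k - 1);
            for (std::size_t t = 0; t < k; ++t) {
                if (t != i && t != j) next.push_back(a[t]);
            }
            next.push_back(a[j] - a[i]);
            total += bruteForceExpectedSquare(next);
        }
    }
    return total / static_cast<double>(k * (k - 1));
}

// Iterates the P and Q recurrences from k = n down to k = 2.
double recurrenceExpectedSquare(const std::vector<double>& a) {
    assert(!a.empty());
    double P = 0.0, Q = 0.0;
    for (double x : a) {
        P += x;      // running sum
        Q += x * x;  // sum of squares, Q_n
    }
    P *= P;          // P_n = (sum)^2
    for (std::size_t k = a.size(); k >= 2; --k) {
        const double kd = static_cast<double>(k);
        const double c = 2.0 / (kd * (kd - 1.0));
        // Both new values are computed from the old P and Q.
        const double nextP = P * (1.0 - 4.0 / kd) + Q * (4.0 / kd);
        const double nextQ = Q * (1.0 + c) - P * c;
        P = nextP;
        Q = nextQ;
    }
    return P;
}
```

Calling both functions on the same short array (three to five elements) should give values that agree up to floating-point rounding.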

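Note: \(S_k\) and the individual elements are random variables, not fixed numbers, so the \(E[\ldots]\) operator must be kept throughout the derivation. In particular, \(P_k = E\left[S_k^2\right]\) is in general different from \(\left(E[S_k]\right)^2\). Each step above is justified by linearity of expectation, averaging over all possible random choices.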
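A minimal example of why the expectation cannot be dropped (added for illustration, assuming both orders of subtraction are equally likely): with the two numbers \(1\) and \(0\), the final value \(S_1\) is \(1\) or \(-1\), each with probability \(\frac 12\).

\[\left(E[S_1]\right)^2 = \left(\frac 12 \cdot 1 + \frac 12 \cdot (-1)\right)^2 = 0, \qquad E\left[S_1^2\right] = \frac 12 \cdot 1 + \frac 12 \cdot 1 = 1\]

The recurrence gives the second quantity: \(P_2 = 1\), \(Q_2 = 1\), so \(P_1 = 1 \cdot \left(1 - \frac 42\right) + 1 \cdot \frac 42 = 1\).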

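// `mint` is assumed to be a modular-integer type defined elsewhere
// (its definition is not shown); the modulus is set to 998244353.
int mint::mod = 998244353;
void solve() {
  int n;
  cin >> n;
  mint Q = 0, P = 0;
  for (int i = 0; i < n; i ++) {
    mint x;
    cin >> x;
    Q += x * x;  // Q_n: sum of squares
    P += x;      // running sum of all elements
  }
  P *= P;        // P_n: square of the sum
  for (int k = n; k >= 2; k --) {
    // invk = 1/k, inv2 = 1/(k(k-1))
    mint invk = mint(1) / k, inv2 = invk / (k - 1);
    // The pair is built from the old P and Q before either is overwritten.
    tie(P, Q) = pair(
      P * (1 - 4 * invk) + Q * 4 * invk,
      Q * (1 + 2 * inv2) - P * 2 * inv2
    );
  }
  cout << P << endl;  // P_1, the answer
}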
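
The code above relies on a `mint` type whose definition is not shown. For readers without such a library, here is a minimal self-contained sketch of the same iteration using plain 64-bit integers. The names `modPow`, `modInverse`, and `expectedFinalSquare` are hypothetical helpers introduced for illustration, and the sketch assumes the number of elements is smaller than the modulus so that every inverse exists.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 998244353;

// Fast exponentiation modulo MOD.
std::int64_t modPow(std::int64_t base, std::int64_t exp) {
    std::int64_t result = 1;
    base %= MOD;
    if (base < 0) base += MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Modular inverse via Fermat's little theorem (MOD is prime).
std::int64_t modInverse(std::int64_t x) { return modPow(x, MOD - 2); }

// Returns P_1 modulo MOD, iterating the P and Q recurrences from k = n to 2.
std::int64_t expectedFinalSquare(const std::vector<std::int64_t>& a) {
    assert(!a.empty());
    std::int64_t sum = 0, Q = 0;
    for (std::int64_t raw : a) {
        const std::int64_t x = ((raw % MOD) + MOD) % MOD;  // normalise negatives
        sum = (sum + x) % MOD;
        Q = (Q + x * x) % MOD;
    }
    std::int64_t P = sum * sum % MOD;
    for (std::int64_t k = static_cast<std::int64_t>(a.size()); k >= 2; --k) {
        const std::int64_t invK = modInverse(k);
        const std::int64_t invPairs = invK * modInverse(k - 1) % MOD;  // 1/(k(k-1))
        const std::int64_t fourOverK = 4 * invK % MOD;
        const std::int64_t twoOverPairs = 2 * invPairs % MOD;
        // Both new values are computed from the old P and Q.
        const std::int64_t nextP =
            (P * ((1 - fourOverK + MOD) % MOD) + Q * fourOverK) % MOD;
        const std::int64_t nextQ =
            (Q * ((1 + twoOverPairs) % MOD) % MOD - P * twoOverPairs % MOD + MOD) % MOD;
        P = nextP;
        Q = nextQ;
    }
    return P;
}
```

For the two-element input `{3, 5}` this returns 4, the square of the difference. For fractional answers the result is the usual modular representation, so multiplying by the denominator recovers the numerator modulo 998244353.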
