Official

C - Bonus Distribution Editorial by admin

Claude 4.6 Opus (Thinking)

Overview

This problem asks us to compute the product \(R\) of the total bonuses received by \(N\) employees over \(D\) months and output it modulo \(998244353\). Simplifying the formula reduces the task to a few sums, products, and modular powers.

Analysis

Simplifying the formula is the key

The total bonus for each employee \(i\) is:

\[T_i = \frac{P_i}{S} \times \sum_{j=1}^{D} W_j\]

where \(S = \sum_{i=1}^{N} P_i\). What we want is the product of total bonuses across all employees:

\[R = \prod_{i=1}^{N} T_i = \prod_{i=1}^{N} \left( \frac{P_i}{S} \times \sum_{j=1}^{D} W_j \right)\]

In each \(T_i\), the factors \(\frac{1}{S}\) and \(\sum W_j\) do not depend on \(i\), so each of them appears \(N\) times in the product of the \(N\) terms:

\[R = \frac{\left(\prod_{i=1}^{N} P_i\right) \times \left(\sum_{j=1}^{D} W_j\right)^N}{S^N}\]

In this way, a seemingly complex problem reduces to computing three values:

  1. \(\prod_{i=1}^{N} P_i\): the product of all employees’ contribution values
  2. \(\left(\sum_{j=1}^{D} W_j\right)^N\): the sum of bonus funds raised to the \(N\)-th power
  3. \(S^N\): the sum of contribution values raised to the \(N\)-th power
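As a sanity check of the simplification, the sketch below compares the direct product of the \(T_i\) with the closed form using exact rational arithmetic (Python's fractions.Fraction). The function names direct_product and closed_form are illustrative only, not part of the original solution, and exact arithmetic like this is practical only for small inputs.

```python
from fractions import Fraction
from math import prod

def direct_product(P, W):
    """Multiply T_i = (P_i / S) * sum(W) over all i, exactly."""
    S = sum(P)
    total_w = sum(W)
    result = Fraction(1)
    for p in P:
        result *= Fraction(p, S) * total_w
    return result

def closed_form(P, W):
    """R = prod(P_i) * sum(W)^N / S^N, exactly."""
    N = len(P)
    return Fraction(prod(P) * sum(W) ** N, sum(P) ** N)

P, W = [1, 2, 5], [4, 7]
assert direct_product(P, W) == closed_form(P, W)
print(closed_form(P, W))  # 6655/256
```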

What if we compute with exact rational numbers?

Since \(P_i\) and \(W_j\) can be up to \(10^9\) and \(N\) can be up to \(10^6\), \(\prod P_i\) alone could have about \(9 \times 10^6\) decimal digits, so computing \(R\) exactly with arbitrary-precision integers would exceed the time limit (TLE). The key is to perform all computations modulo \(998244353\).
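Working modulo \(998244353\) means a rational value \(p/q\) is represented by the residue \(p \cdot q^{-1} \bmod M\). Because residues of sums and products can be reduced at every step, nothing ever grows large. A minimal sketch, assuming \(q\) is not a multiple of \(M\) (the helper name to_mod is hypothetical):

```python
MOD = 998244353  # prime, so Fermat's little theorem gives inverses

def to_mod(p, q, mod=MOD):
    """Residue representing the rational p/q, assuming q % mod != 0."""
    # Fermat: q^(mod-2) is the inverse of q modulo a prime mod.
    return p % mod * pow(q, mod - 2, mod) % mod

x = to_mod(2, 3)
assert x * 3 % MOD == 2      # x behaves like 2/3 under modular arithmetic
assert to_mod(18, 1) == 18   # integers smaller than MOD map to themselves
print(x)  # 665496236
```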

Cases where \(R = 0\)

Since \(P_i \geq 1\), we have \(\prod P_i > 0\) and \(S > 0\). The only case where \(R = 0\) is when \(\sum W_j = 0\), that is, when all \(W_j = 0\) (since \(W_j \geq 0\)). We handle this case by checking for it first.
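Because every \(W_j \geq 0\), testing whether all \(W_j\) are 0 is the same as testing whether the exact sum is 0, whereas testing the residue sum(W) % M is a different condition. The sketch below (hypothetical helper names) shows that the two can disagree, and that in the disagreeing case the factor \(\text{sumW}^N\) is already 0 modulo \(M\), so the modular formula yields 0 as well.

```python
MOD = 998244353

def sum_is_zero(W):
    # Exact test: a sum of non-negative integers is 0 iff every term is 0.
    return all(w == 0 for w in W)

def residue_is_zero(W):
    # Residue test: also true when sum(W) is a positive multiple of MOD.
    return sum(W) % MOD == 0

W = [MOD - 1, 1]  # exact sum is MOD, so the residue is 0
assert not sum_is_zero(W)
assert residue_is_zero(W)
# sumW^N is then 0 modulo MOD, so the whole modular product is 0 too.
assert pow(sum(W) % MOD, 2, MOD) == 0
```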

Algorithm

  1. Read the input.
  2. If \(\sum W_j = 0\) (all \(W_j = 0\)), output \(0\) and terminate.
  3. Compute the following modulo \(M = 998244353\):
    • \(\text{prodP} = \prod_{i=1}^{N} P_i \mod M\)
    • \(\text{sumW} = \sum_{j=1}^{D} W_j \mod M\)
    • \(S = \sum_{i=1}^{N} P_i \mod M\)
  4. Compute the powers:
    • \(\text{sumW}^N \mod M\) by fast exponentiation
    • the modular inverse of \(S^N\) by Fermat's little theorem: since \(M\) is prime and \(S \not\equiv 0 \pmod M\), \((S^N)^{-1} \equiv (S^N)^{M-2} \mod M\)
  5. The answer is \(\text{prodP} \times \text{sumW}^N \times (S^N)^{-1} \mod M\).
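Steps 2 to 5 can be packaged as a pure function, which makes them easy to test separately from input parsing. This is a sketch; the name solve and its list arguments are assumptions, and it relies on \(S \not\equiv 0 \pmod M\) as the problem guarantees.

```python
MOD = 998244353

def solve(P, W):
    """R = prod(P) * sum(W)^N / S^N, reduced modulo MOD."""
    if all(w == 0 for w in W):          # step 2: exact sum of W is 0, so R = 0
        return 0
    N = len(P)
    prod_p = 1
    for p in P:                         # step 3: prod(P_i) mod MOD
        prod_p = prod_p * p % MOD
    sum_w = sum(W) % MOD                # step 3: sum(W_j) mod MOD
    s = sum(P) % MOD                    # step 3: S mod MOD
    sum_w_n = pow(sum_w, N, MOD)        # step 4: fast exponentiation
    s_n_inv = pow(pow(s, N, MOD), MOD - 2, MOD)  # step 4: Fermat inverse
    return prod_p * sum_w_n % MOD * s_n_inv % MOD  # step 5

print(solve([1, 2], [3, 6]))  # 18
```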

Concrete Example

For \(N=2, D=2, P=[1,2], W=[3,6]\):

  • \(S = 3\), \(\sum W = 9\)
  • \(T_1 = \frac{1}{3} \times 9 = 3\), \(T_2 = \frac{2}{3} \times 9 = 6\)
  • \(R = 3 \times 6 = 18\)
  • Verification with the formula: \(R = \frac{(1 \times 2) \times 9^2}{3^2} = \frac{2 \times 81}{9} = 18\)
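The hand calculation above generalizes into a randomized cross-check: compute \(R\) exactly with fractions.Fraction, map the reduced fraction \(p/q\) to \(p \cdot q^{-1} \bmod M\), and compare with the modular formula. The helper names are hypothetical, and values are kept small so the exact arithmetic stays cheap (this also keeps \(S < M\), so every denominator is invertible).

```python
import random
from fractions import Fraction

MOD = 998244353

def exact_R(P, W):
    S, total_w = sum(P), sum(W)
    r = Fraction(1)
    for p in P:
        r *= Fraction(p, S) * total_w  # T_i = P_i / S * sum(W)
    return r

def modular_R(P, W):
    N = len(P)
    prod_p = 1
    for p in P:
        prod_p = prod_p * p % MOD
    s_n_inv = pow(pow(sum(P) % MOD, N, MOD), MOD - 2, MOD)
    return prod_p * pow(sum(W) % MOD, N, MOD) % MOD * s_n_inv % MOD

def fraction_to_mod(x):
    # A reduced Fraction p/q maps to p * q^(-1) mod MOD (q is coprime to MOD here).
    return x.numerator % MOD * pow(x.denominator, MOD - 2, MOD) % MOD

random.seed(0)
for _ in range(200):
    P = [random.randint(1, 20) for _ in range(random.randint(1, 6))]
    W = [random.randint(0, 20) for _ in range(random.randint(1, 6))]
    assert fraction_to_mod(exact_R(P, W)) == modular_R(P, W)
print("all random tests passed")
```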

Complexity

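  • Time complexity: \(O(N + D + \log M)\) (the sums and products take \(O(N + D)\); the two powers with exponent \(N\) take \(O(\log N)\) multiplications, which is absorbed by \(O(N)\), and the Fermat inverse takes \(O(\log M)\))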
  • Space complexity: \(O(N + D)\) (for storing the input)

Implementation Notes

  • Check \(\sum W_j = 0\) using the actual values: a residue of \(0\) does not imply an actual sum of \(0\), since the sum could be a positive multiple of \(998244353\). Because \(W_j \geq 0\), the exact sum is \(0\) exactly when every \(W_j\) is \(0\), so the code tests that directly.

  • It is guaranteed that \(S\) is not a multiple of \(998244353\), so the modular inverse of \(S^N\) always exists.

  • Python’s pow(a, b, MOD) internally uses fast exponentiation, so it efficiently computes powers and modular inverses.

  • To handle large inputs, we use sys.stdin.buffer.read() for bulk reading.
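On Python 3.8 and later, pow also accepts an exponent of -1 together with a modulus and returns the modular inverse directly (raising ValueError if none exists); for a prime modulus this agrees with the Fermat form used in the solution. A small sketch comparing the two, with hypothetical helper names:

```python
MOD = 998244353

def inv_fermat(a):
    # Fermat's little theorem: a^(MOD-2) is a^(-1) mod MOD when MOD is prime.
    return pow(a, MOD - 2, MOD)

def inv_builtin(a):
    # Python 3.8+: pow(a, -1, m) computes the modular inverse directly.
    return pow(a, -1, MOD)

for a in (1, 2, 3, 9, 123456789, MOD - 1):
    assert inv_fermat(a) == inv_builtin(a)
    assert a * inv_fermat(a) % MOD == 1
```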

Source Code

import sys

def main():
    input_data = sys.stdin.buffer.read().split()
    idx = 0
    N = int(input_data[idx]); idx += 1
    D = int(input_data[idx]); idx += 1
    
    MOD = 998244353
    
    P = [int(input_data[idx + i]) for i in range(N)]
    idx += N
    W = [int(input_data[idx + i]) for i in range(D)]
    idx += D
    
    # T_i = P_i / S * sum(W)
    # R = prod(T_i) = prod(P_i) * (sum(W))^N / S^N
    
    # sum(W) is needed only modulo MOD for the formula.
    sumW = sum(W) % MOD
    
    # Since every W_j >= 0, the exact sum of W is 0 iff all W_j are 0;
    # then every T_i is 0, so R = 0.
    if all(w == 0 for w in W):
        print(0)
        return
    
    # P_i >= 1 by the constraints, so prod(P_i) > 0 and S > 0.
    
    # R = prod(P_i) * sumW^N / S^N
    # Compute everything mod MOD
    
    # S = sum(P_i)
    S = 0
    for p in P:
        S = (S + p) % MOD
    
    # prod(P_i) mod MOD
    prodP = 1
    for p in P:
        prodP = prodP * (p % MOD) % MOD
    
    # sumW^N mod MOD
    sumW_N = pow(sumW, N, MOD)
    
    # S^N mod MOD
    S_N = pow(S, N, MOD)
    
    # S_N inverse
    S_N_inv = pow(S_N, MOD - 2, MOD)
    
    ans = prodP * sumW_N % MOD * S_N_inv % MOD
    
    print(ans)

main()

This editorial was generated by claude4.6opus-thinking.
