Official

C - ボーナスの分配 / Bonus Distribution Editorial by admin

DeepSeek V3

Overview

This problem asks us to distribute bonuses over \(D\) months to \(N\) employees based on their contribution levels, and find the product of all employees’ total bonuses modulo \(998244353\).

Analysis

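Simplifying the formula in the problem statement, the total bonus received by employee \(i\) over all \(D\) months is \(T_i = \frac{P_i}{S} \times W_{total}\), where \(S = \sum_{i=1}^{N} P_i\) is the total contribution and \(W_{total} = \sum_{j=1}^{D} W_j\) is the total bonus fund. Therefore, the desired product \(R\) is: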

\[R = \prod_{i=1}^{N} T_i = \prod_{i=1}^{N} \left( \frac{P_i \times W_{total}}{S} \right) = \frac{\left( \prod_{i=1}^{N} P_i \right) \times (W_{total})^N}{S^N}\]
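To sanity-check the simplification, the sketch below compares the product of the individual totals \(T_i\) with the closed form using exact rational arithmetic. The contribution and bonus values are made up for illustration, and it assumes \(T_i = P_i \times W_{total} / S\) with \(S = \sum P_i\), as derived above.

```python
from fractions import Fraction
from math import prod

# Hypothetical sample data (not taken from the problem statement)
P = [3, 1, 2]        # contribution levels P_i
W = [10, 20, 5]      # monthly bonus funds W_j

S = sum(P)           # total contribution
W_total = sum(W)     # total bonus fund
N = len(P)

# Product of each employee's total bonus T_i = P_i * W_total / S
lhs = prod(Fraction(p * W_total, S) for p in P)

# Closed form: (prod P_i) * W_total^N / S^N
rhs = Fraction(prod(P) * W_total ** N, S ** N)

print(lhs, rhs, lhs == rhs)  # → 42875/36 42875/36 True
```

Both sides agree exactly, which is what allows the numerator and denominator to be computed separately.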

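Writing this expression as an irreducible fraction \(\frac{A}{B}\), we must output \(A \times B^{-1} \bmod 998244353\). This value does not change when the numerator and denominator are multiplied by a common factor, as long as the denominator is not a multiple of the prime \(998244353\), so the fraction never has to be reduced explicitly.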
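A small check of that invariance, as a minimal sketch: `frac_mod` is a hypothetical helper name, and the two fractions are illustrative values representing the same rational number in reduced and unreduced form.

```python
MOD = 998244353

def frac_mod(num, den, mod=MOD):
    """Return num * den^{-1} mod `mod`.

    Assumes `mod` is prime and `den` is not a multiple of it, so the
    inverse exists and can be taken via Fermat's little theorem.
    """
    return num % mod * pow(den % mod, mod - 2, mod) % mod

# 42875/36 is irreducible; 257250/216 is the same value scaled by 6
a = frac_mod(42875, 36)
b = frac_mod(257250, 216)
print(a == b)  # → True
```

Because the scaling factor cancels modulo a prime, any unreduced form of \(\frac{A}{B}\) whose denominator is coprime to \(998244353\) yields the same answer.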

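Computing the numerator and denominator as exact integers takes \(N\) multiplications each, but with \(N\) up to \(10^6\) those integers would grow extremely large. Reducing modulo \(998244353\) after every multiplication keeps each intermediate value below \(998244353^2\), so every step takes constant time and the whole computation runs in \(O(N)\).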
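The following minimal sketch contrasts step-by-step reduction with reducing an exact big-integer product once at the end. `product_mod` is a hypothetical helper, and the random values only stand in for real input.

```python
import random
from math import prod

MOD = 998244353

def product_mod(values, mod=MOD):
    """Multiply all values, reducing after every step.

    Each intermediate product is (< mod) * (< mod), so it stays below mod**2.
    """
    acc = 1
    for v in values:
        acc = acc * (v % mod) % mod
    return acc

# Hypothetical data: 1000 random factors, sizes chosen only for illustration
random.seed(0)
vals = [random.randint(1, 10**9) for _ in range(1000)]

# The exact product gives the same residue, but its intermediates grow to
# thousands of digits before the single final reduction.
print(product_mod(vals) == prod(vals) % MOD)  # → True
```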

Algorithm

  1. Compute the total contribution \(S\) and the total bonus fund \(W_{total}\)
  2. Compute the numerator as \(\prod_{i=1}^{N} (P_i \times W_{total}) \mod 998244353\)
  3. Compute the denominator as \(S^N \mod 998244353\)
  4. Multiply the numerator by the modular inverse of the denominator, and output the result modulo \(998244353\)
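The four steps can be sketched as a pure function and cross-checked against an exact rational computation on small inputs. This is an illustrative sketch, not the reference solution: `solve` and `brute` are hypothetical names, and the input is passed as lists rather than read from standard input.

```python
from fractions import Fraction

MOD = 998244353

def solve(P, W, mod=MOD):
    """Steps 1-4: prod(P_i * W_total / S) as A * B^{-1} mod `mod`.

    Assumes S = sum(P) is not a multiple of `mod`, as the constraints guarantee.
    """
    N = len(P)
    S = sum(P) % mod                  # step 1: total contribution
    W_total = sum(W) % mod            # step 1: total bonus fund
    numerator = 1
    for p in P:                       # step 2: prod(P_i * W_total) mod `mod`
        numerator = numerator * (p % mod) % mod * W_total % mod
    denominator = pow(S, N, mod)      # step 3: S^N mod `mod`
    return numerator * pow(denominator, mod - 2, mod) % mod  # step 4

def brute(P, W, mod=MOD):
    """Exact rational product converted to mod at the end (small inputs only)."""
    S, W_total = sum(P), sum(W)
    r = Fraction(1)
    for p in P:
        r *= Fraction(p * W_total, S)
    return r.numerator % mod * pow(r.denominator % mod, mod - 2, mod) % mod

print(solve([3, 1, 2], [10, 20, 5]) == brute([3, 1, 2], [10, 20, 5]))  # → True
```

The brute-force version reduces the fraction first and the fast version does not; agreement between them reflects the fact that reduction is unnecessary.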

Complexity

  • Time complexity: \(O(N + D + \log MOD)\), which is \(O(N + D)\) because MOD is a fixed constant
    • \(O(N + D)\) for reading input and computing sums
    • \(O(N + \log MOD)\) for multiplications and exponentiation
  • Space complexity: \(O(N + D)\)
    • Memory required to store the input data

Implementation Notes

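  • Reduce modulo \(998244353\) after every multiplication. In languages with fixed-width integers this prevents overflow; in Python, whose integers never overflow, it keeps the numbers small so each multiplication stays fast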

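  • Use fast exponentiation (the binary method) for the power of the denominator and for the modular inverse \(x^{-1} \equiv x^{MOD-2}\) (Fermat's little theorem, since \(998244353\) is prime); Python's built-in three-argument pow(x, e, MOD) already does this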

  • Since the input is large, reading it all at once with sys.stdin.read() is faster than calling input() line by line

  • The constraints guarantee that \(S\) is not a multiple of \(998244353\), so the modular inverse can be computed
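Python's built-in pow already implements binary exponentiation; the hand-written version below only makes the method explicit. It is a minimal sketch: `mod_pow` and `mod_inv` are hypothetical names, and the value of \(S\) is illustrative.

```python
MOD = 998244353

def mod_pow(base, exp, mod=MOD):
    """Binary (square-and-multiply) exponentiation: O(log exp) multiplications."""
    result = 1
    base %= mod
    while exp > 0:
        if exp & 1:                  # lowest remaining bit is set: multiply it in
            result = result * base % mod
        base = base * base % mod     # square for the next bit
        exp >>= 1
    return result

def mod_inv(x, mod=MOD):
    """Inverse via Fermat's little theorem; valid because mod is prime and x % mod != 0."""
    return mod_pow(x, mod - 2, mod)

S = 6          # hypothetical total contribution
N = 3
den = mod_pow(S, N)                                      # S^N mod MOD
print(den == pow(S, N, MOD), den * mod_inv(den) % MOD)   # → True 1
```

If \(S\) were a multiple of \(998244353\), mod_inv would return 0 instead of an inverse, which is why the constraint on \(S\) matters.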

Source Code

MOD = 998244353

def main():
    import sys
    data = sys.stdin.read().split()
    it = iter(data)
    N = int(next(it)); D = int(next(it))
    P_list = [int(next(it)) for _ in range(N)]
    W_list = [int(next(it)) for _ in range(D)]

    # Total contribution S and total bonus fund W_total, reduced mod MOD
    S_mod = sum(P_list) % MOD
    W_mod = sum(W_list) % MOD

    # Numerator: prod(P_i * W_total), reduced after every multiplication
    numerator = 1
    for p in P_list:
        numerator = numerator * (p % MOD) % MOD * W_mod % MOD

    # Denominator: S^N via fast (binary) exponentiation
    denominator = pow(S_mod, N, MOD)

    # Divide by multiplying with the inverse (Fermat: MOD is prime, S_mod != 0)
    R_mod = numerator * pow(denominator, MOD - 2, MOD) % MOD
    print(R_mod)


if __name__ == "__main__":
    main()

This editorial was generated by deepseekv3.
