C - ボーナスの分配 / Bonus Distribution Editorial by admin
DeepSeek V3
Overview
This problem asks us to distribute bonuses over \(D\) months to \(N\) employees based on their contribution levels, and find the product of all employees’ total bonuses modulo \(998244353\).
Analysis
Simplifying the formula in the problem statement, the total bonus received by employee \(i\) is \(T_i = \frac{P_i}{S} \times W_{total}\). Here \(S = \sum_{i=1}^{N} P_i\) is the total contribution and \(W_{total} = \sum_{j=1}^{D} W_j\) is the total bonus fund. The desired product \(R\) is therefore:
\[R = \prod_{i=1}^{N} T_i = \prod_{i=1}^{N} \left( \frac{P_i \times W_{total}}{S} \right) = \frac{\left( \prod_{i=1}^{N} P_i \right) \times (W_{total})^N}{S^N}\]
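When this expression is written as an irreducible fraction \(\frac{A}{B}\), we need to compute \(A \times B^{-1} \bmod 998244353\). The fraction does not have to be reduced first. Because \(998244353\) is prime, every representation of the same rational number whose denominator is not a multiple of \(998244353\) gives the same residue.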
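The small sketch below checks the closed form for \(R\) with exact rational arithmetic (Python's fractions.Fraction). It also shows that the reduced and unreduced fractions give the same residue modulo \(998244353\). The input values are hypothetical and are not taken from the problem's samples. The helper names product_direct and fraction_mod are illustrative only.

```python
from fractions import Fraction
from math import prod

MOD = 998244353

def product_direct(P, W):
    """Multiply T_i = P_i / S * W_total exactly, one employee at a time."""
    S, W_total = sum(P), sum(W)
    result = Fraction(1)
    for p in P:
        result *= Fraction(p, S) * W_total
    return result

def fraction_mod(num, den):
    """num * den^{-1} mod MOD; requires den not divisible by MOD."""
    if den % MOD == 0:
        raise ValueError("denominator is divisible by MOD")
    return num % MOD * pow(den, MOD - 2, MOD) % MOD

# Hypothetical small example: N = 3, D = 2 (not one of the problem's samples).
P, W = [1, 2, 3], [4, 6]
N = len(P)
num = prod(P) * sum(W) ** N     # unreduced numerator: 6 * 10^3 = 6000
den = sum(P) ** N               # unreduced denominator: 6^3 = 216
R = product_direct(P, W)        # Fraction reduces automatically
print(R)                                              # 250/9
print(Fraction(num, den) == R)                        # True
print(fraction_mod(num, den) == fraction_mod(R.numerator, R.denominator))  # True
```

Since \(6000/216\) and \(250/9\) map to the same residue, the implementation can work directly with \(\left(\prod P_i\right) W_{total}^N\) and \(S^N\) without computing a gcd.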
Computing the numerator and denominator as exact integers takes \(N\) multiplications each. Since \(N\) can be up to \(10^6\), the resulting numbers would be enormous. Instead, we reduce modulo \(998244353\) after every multiplication. This keeps every intermediate value small, so each step takes \(O(1)\) time.
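The following is a minimal sketch of the "reduce after every multiplication" idea. The helper name prod_mod is hypothetical. In fixed-width languages this pattern keeps intermediates within 64 bits. In Python it stops the integer from growing to millions of digits.

```python
MOD = 998244353

def prod_mod(values):
    """Product of values modulo MOD, reducing after every multiplication.

    Both factors are below MOD before each step, so every intermediate
    product is below MOD**2 (about 10**18)."""
    acc = 1
    for v in values:
        acc = acc * (v % MOD) % MOD
    return acc

# Cross-check against the exact big-integer product on a small input.
values = list(range(1, 1001))
exact = 1
for v in values:
    exact *= v          # grows to 1000!, a number with over 2,500 digits
print(prod_mod(values) == exact % MOD)  # True
```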
Algorithm
- Compute the total contribution \(S\) and the total bonus fund \(W_{total}\)
- Compute the numerator as \(\prod_{i=1}^{N} (P_i \times W_{total}) \mod 998244353\)
- Compute the denominator as \(S^N \mod 998244353\)
- Multiply the numerator by the modular inverse of the denominator, and output the result modulo \(998244353\)
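The four steps above can be sketched as a single function, with input parsing omitted. The name solve and its list arguments are assumptions made for illustration. In this sketch \(S^N\) is computed with Python's built-in three-argument pow instead of a multiplication loop.

```python
MOD = 998244353

def solve(P, W):
    """Return R mod MOD for contributions P and monthly funds W
    (hypothetical helper; assumes sum(P) is not a multiple of MOD)."""
    N = len(P)
    S = sum(P) % MOD          # step 1: total contribution
    W_total = sum(W) % MOD    # step 1: total bonus fund
    numerator = 1
    for p in P:               # step 2: prod(P_i * W_total), reduced each step
        numerator = numerator * (p % MOD) % MOD * W_total % MOD
    denominator = pow(S, N, MOD)  # step 3: S^N by fast exponentiation
    # step 4: multiply by the modular inverse (Fermat, MOD is prime)
    return numerator * pow(denominator, MOD - 2, MOD) % MOD
```

For example, with P = [2, 2] and W = [3], each employee gets \(3/2\), so \(R = 9/4\). Accordingly, solve([2, 2], [3]) * 4 % MOD equals 9.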
Complexity
- Time complexity: \(O(N + D)\)
  - \(O(N + D)\) for reading input and computing sums
  - \(O(N + \log MOD)\) for multiplications and exponentiation
- Space complexity: \(O(N + D)\)
  - Memory required to store the input data
Implementation Notes
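- Perform modular arithmetic at each step so that products of large numbers never overflow (or, in Python, never grow without bound)
- Use fast exponentiation (binary method) for the modular inverse of the denominator, and optionally for \(S^N\) itself; each call needs only a logarithmic number of multiplications in the exponent
- Since the input can be large, reading it all at once with sys.stdin.read() is efficient
- The constraints guarantee that \(S\) is not a multiple of \(998244353\), so \(S^N\) has a modular inverse; because \(998244353\) is prime, it can be computed as \((S^N)^{998244353-2}\) by Fermat's little theorem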
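Python's built-in pow(base, exp, mod) already performs binary exponentiation. The hand-written version below only illustrates the technique and the Fermat inverse it enables. The names power_mod and inverse_mod are hypothetical.

```python
MOD = 998244353  # prime, so Fermat's little theorem applies

def power_mod(base, exp, mod):
    """Binary (square-and-multiply) exponentiation: O(log exp) multiplications."""
    result = 1
    base %= mod
    while exp > 0:
        if exp & 1:               # lowest remaining bit is set: use this power
            result = result * base % mod
        base = base * base % mod  # advance to base^(2^(k+1))
        exp >>= 1
    return result

def inverse_mod(x):
    """x^(MOD-2) is the inverse of x when MOD is prime and x % MOD != 0."""
    if x % MOD == 0:
        raise ValueError("no inverse: x is a multiple of MOD")
    return power_mod(x, MOD - 2, MOD)
```

On Python 3.8 and later, pow(x, -1, MOD) also returns the modular inverse directly.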
Source Code
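MOD = 998244353


def main():
    import sys
    # Read all tokens at once: N, D, then P_1..P_N, then W_1..W_D
    data = sys.stdin.read().split()
    it = iter(data)
    N = int(next(it)); D = int(next(it))
    P_list = [int(next(it)) for _ in range(N)]
    W_list = [int(next(it)) for _ in range(D)]
    # S = total contribution, W_total = total bonus fund
    total_S = sum(P_list)
    total_W = sum(W_list)
    # Reduce both sums once so every later product stays small
    S_mod = total_S % MOD
    W_mod = total_W % MOD
    # Numerator: prod(P_i * W_total) mod MOD, reduced after every step
    numerator = 1
    for p in P_list:
        numerator = numerator * (p % MOD) % MOD * W_mod % MOD
    # Denominator: S^N mod MOD by fast exponentiation
    denominator = pow(S_mod, N, MOD)
    # Divide by multiplying with the modular inverse (Fermat; MOD is prime)
    R_mod = numerator * pow(denominator, MOD - 2, MOD) % MOD
    print(R_mod)


if __name__ == "__main__":
    main()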
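One way to gain confidence in the modular computation is to compare it against an exact rational computation on small random inputs. In the sketch below, fast mirrors the arithmetic of the source code above. brute multiplies the \(T_i\) with fractions.Fraction and converts the result at the end. The value ranges (1 to 100, at most 6 employees and months) are arbitrary assumptions, not the problem's constraints.

```python
import random
from fractions import Fraction

MOD = 998244353

def fast(P, W):
    """Modular computation mirroring the source code above."""
    S = sum(P) % MOD
    Wt = sum(W) % MOD
    num = 1
    for p in P:
        num = num * (p % MOD) % MOD * Wt % MOD
    return num * pow(pow(S, len(P), MOD), MOD - 2, MOD) % MOD

def brute(P, W):
    """Exact product of T_i with Fractions, then converted to mod MOD."""
    S, Wt = sum(P), sum(W)
    R = Fraction(1)
    for p in P:
        R *= Fraction(p * Wt, S)
    return R.numerator % MOD * pow(R.denominator, MOD - 2, MOD) % MOD

random.seed(0)
for _ in range(200):
    N = random.randint(1, 6)
    D = random.randint(1, 6)
    P = [random.randint(1, 100) for _ in range(N)]
    W = [random.randint(1, 100) for _ in range(D)]
    assert fast(P, W) == brute(P, W)
print("all random tests passed")
```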
This editorial was generated by deepseekv3.