D - Fraction Line Editorial by evima
Hints: https://atcoder.jp/contests/arc192/editorial/12197
Let \(p_i\) denote the \(i\)-th smallest prime number.
Also, for integers \(x\ge 1\) and \(y\ge 2\), let \(d_{y}(x)\) denote the largest nonnegative integer \(e\) such that \(y^{e}\) divides \(x\).
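One can show that for each \(i\), \(d_{p_i}\left(f\left(\dfrac{x}{y}\right)\right)=\left|d_{p_i}(x)-d_{p_i}(y)\right|\). Indeed, reducing \(\dfrac{x}{y}\) to lowest terms divides both \(x\) and \(y\) by \(\gcd(x,y)\), whose exponent of \(p_i\) is \(\min(d_{p_i}(x),d_{p_i}(y))\). Afterwards \(p_i\) divides at most one of the numerator and the denominator, with exponent \(\left|d_{p_i}(x)-d_{p_i}(y)\right|\), and \(f\) multiplies the two. Hence the constraints on different prime factors do not interact, and we can process each prime independently. Since the set of valid sequences is the product of the sets of valid exponent sequences for the individual primes, and the quantity being summed factors over primes in the same way, the final answer is the product of the per-prime answers.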
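As an illustration only, the identity can be spot-checked by brute force on small values. In this sketch the helper names `d`, `f`, and `check_identity` are hypothetical, and \(f\) is assumed to multiply the numerator and denominator of the fraction in lowest terms (Python's `Fraction` reduces automatically).

```python
from fractions import Fraction


def d(y: int, x: int) -> int:
    """Largest e such that y**e divides x (requires x >= 1, y >= 2)."""
    e = 0
    while x % y == 0:
        x //= y
        e += 1
    return e


def f(q: Fraction) -> int:
    # Assumption: f multiplies numerator and denominator of the reduced fraction.
    return q.numerator * q.denominator


def check_identity(limit: int, primes=(2, 3, 5, 7)) -> bool:
    """Check d_p(f(x/y)) == |d_p(x) - d_p(y)| for all 1 <= x, y <= limit."""
    for x in range(1, limit + 1):
        for y in range(1, limit + 1):
            v = f(Fraction(x, y))
            for p in primes:
                if d(p, v) != abs(d(p, x) - d(p, y)):
                    return False
    return True


print(check_identity(60))  # True
```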
Fix a prime \(p_i\) and let \(s_i=(d_{p_i}(S_1),d_{p_i}(S_2),\dots,d_{p_i}(S_N))\). Then \(s_i\) must satisfy both of the following:
- \(\min(s_i)=0\)
- \(\left|s_{i,j}-s_{i,j+1}\right|=d_{p_i}\left(A_j\right)\) for each \(j\) such that \(1\leq j\lt N\)
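To make the two conditions concrete, the following hypothetical brute force enumerates every candidate \(s_i\) for one prime directly. It assumes `a[j]` holds \(d_{p_i}(A_j)\) (0-indexed) and is exponential in \(N\), so it is only a reference for tiny cases.

```python
from itertools import product


def brute_single_prime(p: int, a: list[int], mod: int = 998244353) -> int:
    """Sum of p**sum(s) over sequences s of length len(a) + 1 with
    min(s) == 0 and |s[j] - s[j+1]| == a[j] for every j."""
    n = len(a) + 1
    # No entry can exceed sum(a): some entry is 0 and each step moves by a[j].
    hi = sum(a)
    total = 0
    for s in product(range(hi + 1), repeat=n):
        if min(s) != 0:
            continue
        if all(abs(s[j] - s[j + 1]) == a[j] for j in range(n - 1)):
            total += pow(p, sum(s), mod)
    return total % mod
```

For example, `brute_single_prime(3, [1, 1])` finds \((0,1,0),(1,0,1),(0,1,2),(2,1,0)\) and returns \(3+9+27+27=66\).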
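Since \(p_i\) contributes \(p_i^{\text{sum}(s_i)}\) to the product \(S_1S_2\cdots S_N\), we need the sum of \(p_i^{\text{sum}(s_i)}\) over all such \(s_i\). This can be computed by a DP over \(j\) whose state is the current value \(s_{i,j}\) together with a flag indicating whether some of \(s_{i,1},\dots,s_{i,j}\) is already \(0\); only states with the flag set are counted at the end. For the transitions, see the Sample Implementation.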
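One way to write this DP for a single prime is sketched below, using a dictionary keyed by (value, flag) instead of the arrays of the Sample Implementation. The name `dp_single_prime` is hypothetical and `a[j]` stands for \(d_{p_i}(A_j)\). Like the Sample Implementation, it caps values at `sum(a)`: a sequence containing \(0\) never exceeds that bound, and a prefix that exceeds it can never come back down to \(0\).

```python
MOD = 998244353


def dp_single_prime(p: int, a: list[int], mod: int = MOD) -> int:
    """Sum of p**sum(s) over valid exponent sequences s for one prime.

    dp[(v, seen_zero)] = total weight of prefixes ending at value v,
    where seen_zero records whether some entry so far equals 0.
    """
    hi = sum(a)  # every entry of a valid s lies in [0, hi]
    # The first entry can be any value v in [0, hi]; its weight is p**v.
    dp = {(v, v == 0): pow(p, v, mod) for v in range(hi + 1)}
    for step in a:
        ndp = {}
        for (v, seen), w in dp.items():
            # When step == 0 the set has one element, so nothing is counted twice.
            for nv in {v - step, v + step}:
                if 0 <= nv <= hi:
                    key = (nv, seen or nv == 0)
                    ndp[key] = (ndp.get(key, 0) + w * pow(p, nv, mod)) % mod
        dp = ndp
    # Keep only sequences whose minimum is 0.
    return sum(w for (v, seen), w in dp.items() if seen) % mod
```

For `p = 2, a = [2, 1]` it returns 64, matching the four valid sequences \((2,0,1),(3,1,0),(0,2,1),(0,2,3)\) with weights \(8+16+8+32\).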
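Let us estimate the complexity. For the prime \(p_i\), every entry of \(s_i\) is at most \(\sum_{j=1}^{N-1}d_{p_i}(A_j)\), because some entry is \(0\) and consecutive entries differ by \(d_{p_i}(A_j)\). Since \(\sum_{i=1}^{\infty}\sum_{j=1}^{N-1}d_{p_i}(A_j)\le\sum_{j=1}^{N-1}\log_2 A_j=O(N\log\max(A))\), the DP tables over all primes have \(O(N^2\log\max(A))\) entries in total. Each transition takes \(O(1)\) time, so including prime factorization by trial division, the complexity is \(O(N\sqrt{\max(A)}+N^2\log\max(A))\), which is fast enough. Even an \(O(N\max(A))\) factorization, as used in the Sample Implementation, should pass.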
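The \(O(N\sqrt{\max(A)})\) factorization step might look like the sketch below, which groups the exponents of each prime across all \(A_j\) so that each list can be fed to a per-prime DP. The function name `exponent_lists` is hypothetical.

```python
from collections import defaultdict


def exponent_lists(A: list[int]) -> dict[int, list[int]]:
    """Map each prime p dividing some A_j to [d_p(A_1), ..., d_p(A_{N-1})].

    Trial division costs O(sqrt(A_j)) per element, O(N sqrt(max A)) overall.
    """
    n = len(A)
    result = defaultdict(lambda: [0] * n)
    for j, x in enumerate(A):
        p = 2
        while p * p <= x:
            # Composite p never divides x here: its prime factors are already gone.
            while x % p == 0:
                result[p][j] += 1
                x //= p
            p += 1
        if x > 1:
            # What remains has no divisor up to its square root, so it is prime.
            result[x][j] += 1
    return dict(result)
```

For `A = [1, 9, 2, 2, 9]` it returns `{3: [0, 2, 0, 0, 2], 2: [0, 0, 1, 1, 0]}`.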
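Sample Implementation (Python)

```python
MOD = 998244353

N = int(input())
A = list(map(int, input().split()))

ans = 1
# Trial-divide by every i in [2, max(A)]. A composite i never divides the
# remaining values, because its prime factors were removed earlier.
for i in range(2, max(A) + 1):
    # a[j] = exponent of the prime i in A[j]
    a = [0] * (N - 1)
    for j in range(N - 1):
        while A[j] % i == 0:
            a[j] += 1
            A[j] //= i
    sa = sum(a)
    if sa == 0:
        continue
    # Every exponent in a valid sequence lies in [0, sa]; pows[k] = i^k mod MOD.
    pows = [1] * (sa + 1)
    for j in range(1, sa + 1):
        pows[j] = (pows[j - 1] * i) % MOD
    # dp[k][flag]: weighted count of prefixes whose last exponent is k;
    # flag = 1 if some exponent so far is 0. The first exponent k has weight i^k.
    dp = [[pows[j], 0] for j in range(sa + 1)]
    dp[0] = [0, 1]
    for j in range(N - 1):
        ndp = [[0, 0] for _ in range(sa + 1)]
        for k in range(sa + 1):
            # Move down by a[j]; reaching 0 sets the flag.
            if k - a[j] >= 0:
                nxt = 0
                if k - a[j] == 0:
                    nxt = 1
                ndp[k - a[j]][nxt] += dp[k][0] * pows[k - a[j]]
                ndp[k - a[j]][nxt] %= MOD
                ndp[k - a[j]][1] += dp[k][1] * pows[k - a[j]]
                ndp[k - a[j]][1] %= MOD
            # Move up by a[j] (skipped when a[j] == 0 to avoid double counting).
            if a[j] != 0 and k + a[j] <= sa:
                ndp[k + a[j]][0] += dp[k][0] * pows[k + a[j]]
                ndp[k + a[j]][0] %= MOD
                ndp[k + a[j]][1] += dp[k][1] * pows[k + a[j]]
                ndp[k + a[j]][1] %= MOD
        dp = ndp
    # Keep only sequences whose minimum exponent is 0.
    nsum = 0
    for j in range(sa + 1):
        nsum = (nsum + dp[j][1]) % MOD
    ans = (ans * nsum) % MOD
print(ans)
```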
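For cross-checking the sample implementation on tiny inputs, a brute force can enumerate \(S\) directly. This hypothetical helper is not part of the editorial. It assumes, as the conditions above encode, that a valid \(S\) satisfies \(f(S_j/S_{j+1})=A_j\) for every \(j\) and has \(\gcd(S_1,\dots,S_N)=1\) (equivalently \(\min(s_i)=0\) for every prime), that the summed quantity is \(S_1S_2\cdots S_N\), and that \(f\) multiplies the numerator and denominator in lowest terms.

```python
from fractions import Fraction
from functools import reduce
from itertools import product
from math import gcd, prod

MOD = 998244353


def f(q: Fraction) -> int:
    # Assumption: numerator times denominator of q in lowest terms.
    return q.numerator * q.denominator


def brute_force(A: list[int], mod: int = MOD) -> int:
    """Sum of S_1*...*S_N over all valid S. Exponential; tiny inputs only."""
    n = len(A) + 1
    # Each exponent of p in S_j is at most d_p(prod(A)), so S_j divides prod(A).
    P = prod(A)
    divisors = [x for x in range(1, P + 1) if P % x == 0]
    total = 0
    for S in product(divisors, repeat=n):
        if reduce(gcd, S) != 1:
            continue
        if all(f(Fraction(S[j], S[j + 1])) == A[j] for j in range(n - 1)):
            total += prod(S)
    return total % mod
```

For `A = [9, 2]` it returns 540 from \(S=(1,9,18),(2,18,9),(9,1,2),(18,2,1)\), which equals the per-prime sums \(90\) (for \(3\)) times \(6\) (for \(2\)); the sample implementation prints the same value for the input `3` followed by `9 2`.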