D - Fraction Line Editorial by evima (Official)


Hints: https://atcoder.jp/contests/arc192/editorial/12197


Let \(p_i\) denote the \(i\)‑th smallest prime number.

Also, let \(d_{y}(x)\) denote the largest exponent such that \(y^{\,d_{y}(x)}\) divides \(x\).

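For each \(i\), \(d_{p_i}\left(f\left(\dfrac{x}{y}\right)\right)=\left|d_{p_i}(x)-d_{p_i}(y)\right|\). Indeed, \(f(P/Q)=P\times Q\) for coprime positive integers \(P,Q\). When \(x/y\) is reduced to lowest terms, \(p_i\) survives only in the numerator (with exponent \(d_{p_i}(x)-d_{p_i}(y)\)) or only in the denominator (with exponent \(d_{p_i}(y)-d_{p_i}(x)\)), so its exponent in \(P\times Q\) is the absolute difference. Hence the constraints for different primes do not interact. The condition \(\gcd(S_1,\dots,S_N)=1\) also splits prime by prime, and the score \(S_1\times\cdots\times S_N\) is the product of its prime-power parts, so the answer is the product of the answers for the individual primes. We therefore consider each prime separately.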
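The identity above can be spot-checked numerically. The following is a minimal sketch, not part of the official solution: `d` and `f` are hypothetical helper names, and `fractions.Fraction` is used only because it always stores a fraction in lowest terms, which is exactly the form \(P/Q\) that \(f\) requires.

```python
from fractions import Fraction


def d(p, x):
    """Largest e such that p**e divides x (x >= 1); hypothetical helper."""
    e = 0
    while x % p == 0:
        x //= p
        e += 1
    return e


def f(x):
    """f(P/Q) = P * Q; Fraction keeps numerator/denominator coprime."""
    return x.numerator * x.denominator


# Check d_p(f(x/y)) == |d_p(x) - d_p(y)| for small x, y and a few primes.
for x in range(1, 61):
    for y in range(1, 61):
        v = f(Fraction(x, y))
        for p in (2, 3, 5, 7):
            assert d(p, v) == abs(d(p, x) - d(p, y))
print("identity holds for 1 <= x, y <= 60")
```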

Consider the answer for a prime factor \(p_i\). Let \(s_i=(d_{p_i}(S_1),d_{p_i}(S_2),\dots,d_{p_i}(S_N))\). Then \(s_i\) must satisfy both of the following:

  • \(\min(s_i)=0\), which is equivalent to \(p_i\) not dividing \(\gcd(S_1,\dots,S_N)\)
  • \(\left|s_{i,j}-s_{i,j+1}\right|=d_{p_i}\left(A_j\right)\) for each \(j\) such that \(1\leq j\lt N\)
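For very small cases, the two conditions above can be checked by enumerating every candidate \(s_i\) directly. The sketch below (the name `brute_prime` is hypothetical) relies on the fact that, once \(\min(s_i)=0\), every entry is at most \(\sum_j d_{p_i}(A_j)\), because consecutive entries differ by exactly \(d_{p_i}(A_j)\). It is exponential in \(N\) and only meant as a reference for testing.

```python
from itertools import product


def brute_prime(p, a, mod=998244353):
    """Sum of p**sum(s) over all s with min(s) == 0 and |s[j] - s[j+1]| == a[j].

    a[j] plays the role of d_p(A_{j+1}); hypothetical testing helper.
    """
    n = len(a) + 1
    hi = sum(a)  # with min(s) == 0, no entry can exceed sum(a)
    total = 0
    for s in product(range(hi + 1), repeat=n):
        if min(s) != 0:
            continue
        if all(abs(s[j] - s[j + 1]) == a[j] for j in range(n - 1)):
            total += pow(p, sum(s), mod)
    return total % mod


print(brute_prime(2, [1, 1]))
```

For example, with \(p_i=2\) and exponents \((1,1)\), the valid sequences are \((0,1,0)\), \((0,1,2)\), \((1,0,1)\), \((2,1,0)\), giving \(2+8+4+8=22\).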

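We need to find the sum of \(p_i^{\text{sum}(s_i)}\) over all such \(s_i\). Since \(\min(s_i)=0\) and consecutive entries differ by \(d_{p_i}(A_j)\), every entry lies between \(0\) and \(\sum_{j=1}^{N-1}d_{p_i}(A_j)\). The sum can therefore be computed by a DP over \(j\) whose state is the current value \(s_{i,j}\) together with a flag indicating whether some \(s_{i,j'}\) with \(j'\leq j\) has already equaled \(0\). For the transitions, see the Sample Implementation.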
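The DP can also be written as a standalone function of one prime and its exponent list, which makes it easy to test against exhaustive enumeration. This is a sketch assuming the same state as the Sample Implementation (current value plus a zero-seen flag); `prime_dp` is a hypothetical name, and the exponent list `a` (with `a[j]` standing for \(d_{p_i}(A_{j+1})\)) is assumed to be computed beforehand.

```python
MOD = 998244353


def prime_dp(p, a, mod=MOD):
    """Sum of p**sum(s) over s with min(s) == 0 and |s[j] - s[j+1]| == a[j]."""
    hi = sum(a)  # every exponent lies in [0, hi]
    pw = [pow(p, k, mod) for k in range(hi + 1)]
    # dp[k][z]: weighted count of prefixes ending at exponent k;
    # z == 1 iff a zero exponent has already appeared in the prefix.
    dp = [[pw[k], 0] for k in range(hi + 1)]
    dp[0] = [0, 1]  # starting at 0 already satisfies the min condition
    for step in a:
        ndp = [[0, 0] for _ in range(hi + 1)]
        for k in range(hi + 1):
            # A set avoids counting the same target twice when step == 0.
            for t in {k - step, k + step}:
                if 0 <= t <= hi:
                    w = pw[t]
                    z = 1 if t == 0 else 0  # flag set on reaching 0
                    ndp[t][z] = (ndp[t][z] + dp[k][0] * w) % mod
                    ndp[t][1] = (ndp[t][1] + dp[k][1] * w) % mod
        dp = ndp
    # Only sequences that touched 0 are counted.
    return sum(row[1] for row in dp) % mod


print(prime_dp(2, [1, 1]))
```

Each step touches \(O(\sum_j d_{p_i}(A_j))\) states with \(O(1)\) work per state, matching the table size used in the complexity estimate.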

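Let us estimate the complexity. For a fixed prime \(p_i\), the DP table has \(O\left(N\sum_{j=1}^{N-1}d_{p_i}(A_j)\right)\) entries. Since each \(A_j\) has at most \(\log_2 A_j\) prime factors counted with multiplicity, \(\sum_{i=1}^{\infty}\sum_{j=1}^{N-1}d_{p_i}(A_j)\) is at most \(O(N\log\max(A))\), so the DP tables have \(O(N^2\log\max(A))\) entries in total. Each transition takes \(O(1)\) time, so including the cost of prime factorization by trial division up to \(\sqrt{A_j}\), the complexity is \(O(N\sqrt{\max(A)}+N^2\log\max(A))\), which is fast enough. Even an \(O(N\max(A))\) factorization, like the one in the Sample Implementation, should pass.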
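The \(O(N\sqrt{\max(A)})\) term refers to factorizing each \(A_j\) by trial division up to \(\sqrt{A_j}\), instead of the \(O(N\max(A))\) sweep used in the Sample Implementation. A minimal sketch, assuming the goal is to produce, for every prime \(p\) dividing some \(A_j\), the list \((d_p(A_1),\dots,d_p(A_{N-1}))\) consumed by the per-prime DP; the function name `exponent_lists` is hypothetical.

```python
from collections import defaultdict


def exponent_lists(A):
    """Map each prime p dividing some A[j] to [d_p(A[0]), ..., d_p(A[-1])].

    Trial division while p * p <= x, i.e. O(sqrt(A[j])) per element.
    """
    n = len(A)
    table = defaultdict(lambda: [0] * n)
    for j, x in enumerate(A):
        p = 2
        while p * p <= x:
            while x % p == 0:
                table[p][j] += 1
                x //= p
            p += 1
        if x > 1:  # whatever remains above 1 is a prime
            table[x][j] += 1
    return dict(table)


print(exponent_lists([12, 18]))
```

Multiplying the per-prime DP results over the keys of this dictionary gives the answer; a prime that divides no \(A_j\) contributes a factor of \(1\), because its exponent sequence is forced to be all zeros.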

Sample Implementation (PyPy3)

MOD = 998244353
N = int(input())
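A = list(map(int, input().split()))  # A_1, ..., A_{N-1}
ans = 1

# Trial division by every i in [2, max(A)]. When i is composite, its prime
# factors have already been divided out of every A[j], so sa == 0 and i is skipped.
for i in range(2, max(A) + 1):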
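    # a[j] = d_i(A_j); A[j] is divided by i as the exponent is counted
    a = [0] * (N - 1)
    for j in range(N - 1):
        while A[j] % i == 0:
            a[j] += 1
            A[j] //= i

    # since min(s_i) = 0, every exponent s_{i,j} lies in [0, sa]
    sa = sum(a)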

    if sa == 0:
        continue

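    # pows[k] = i^k mod MOD
    pows = [1] * (sa + 1)

    for j in range(1, sa + 1):
        pows[j] = (pows[j - 1] * i) % MOD

    # dp[k][f]: sum of i^(s_{i,1} + ... + s_{i,j}) over prefixes with s_{i,j} = k,
    # where f = 1 iff some exponent in the prefix is 0
    dp = [[pows[j], 0] for j in range(sa + 1)]
    dp[0] = [0, 1]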

    for j in range(N - 1):
        ndp = [[0, 0] for _ in range(sa + 1)]
        for k in range(sa + 1):
            # next exponent k - a[j]; reaching 0 sets the flag
            if k - a[j] >= 0:
                nxt = 0
                if k - a[j] == 0:
                    nxt = 1

                ndp[k - a[j]][nxt] += dp[k][0] * pows[k - a[j]]
                ndp[k - a[j]][nxt] %= MOD

                ndp[k - a[j]][1] += dp[k][1] * pows[k - a[j]]
                ndp[k - a[j]][1] %= MOD

            # next exponent k + a[j]; skipped when a[j] == 0 so k is not counted twice
            if a[j] != 0 and k + a[j] <= sa:
                ndp[k + a[j]][0] += dp[k][0] * pows[k + a[j]]
                ndp[k + a[j]][0] %= MOD

                ndp[k + a[j]][1] += dp[k][1] * pows[k + a[j]]
                ndp[k + a[j]][1] %= MOD

        dp = ndp

    # keep only sequences whose minimum exponent is 0
    nsum = 0
    for j in range(sa + 1):
        nsum = (nsum + dp[j][1]) % MOD

    # the per-prime answers multiply together
    ans = (ans * nsum) % MOD

print(ans)
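The Sample Implementation can be cross-checked on tiny inputs against exhaustive search over \(S\) itself. The sketch below is a hypothetical testing helper, not part of the official solution. It assumes the problem's definitions (a good sequence satisfies \(f(S_j/S_{j+1})=A_j\) for every \(j\) and \(\gcd(S_1,\dots,S_N)=1\), and its score is \(S_1\times\cdots\times S_N\)), and it searches \(1\leq S_j\leq\prod_j A_j\), which is a valid bound because the exponent of each \(p_i\) in \(S_j\) is at most \(\sum_j d_{p_i}(A_j)\).

```python
from fractions import Fraction
from functools import reduce
from itertools import product
from math import gcd, prod


def brute_total(A, mod=998244353):
    """Sum of scores of all good sequences, by exhaustive search (tiny inputs only)."""
    n = len(A) + 1
    bound = prod(A)  # every S_j divides prod(A), so S_j <= prod(A)
    total = 0
    for S in product(range(1, bound + 1), repeat=n):
        if reduce(gcd, S) != 1:
            continue
        ok = True
        for j in range(n - 1):
            r = Fraction(S[j], S[j + 1])  # stored in lowest terms
            if r.numerator * r.denominator != A[j]:
                ok = False
                break
        if ok:
            total += prod(S)
    return total % mod


print(brute_total([2, 2]))
```

For \(N=3\) and \(A=(2,2)\), both this search and the Sample Implementation give \(22\).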
