G - Takahashi And Pass-The-Ball Game Editorial by evima

別解

\(i\) について頂点 \(i\) から頂点 \(A_i\) へ辺が伸びているような有向グラフ (functional graph) を考えても解けます。各頂点 \(i\) について、\(i\) から辺を \(1, 2, \dots, K\) 回辿った先の延べ \(K\) 個の頂点それぞれに \(B_i/K\) を足すことになります。ここで、functional graph は何個かのサイクルの部分とそれらに属する頂点を根とする木の部分からなり、どの頂点からでも辺を何回か辿るとサイクルに到達することから、上記の「延べ \(K\) 個の頂点」の内訳はあまり複雑ではありません。頂点 \(i\) からサイクル上の頂点への加算と他の頂点への加算をそれぞれ累積和関連の技法を用いて一斉に行うことで(実装例を見てください)、問題の主要部分を線形時間で解けます。

実装例 (Python)

import sys

N, K = map(int, input().split())
sys.setrecursionlimit(2 * N + 99)
MOD = 998244353
A = list(map(lambda x: int(x) - 1, input().split()))
B = list(map(int, input().split()))
rA = [[] for _ in range(N)]
for i in range(N):
    rA[A[i]].append(i)
visited = [False for _ in range(N)]
ans = [0 for _ in range(N)]


def solve(r):
    m = len(r)
    s = set(r)
    a = [0 for _ in range(m + 1)]
    l = []
    loopsK, remK = K // m, K % m

    def dfs(v, pos):
        visited[v] = True
        p = (pos - len(l)) % m
        a[0] += loopsK * B[v] % MOD
        toK = p + remK
        if toK < m:
            a[p + 1] += B[v]
            a[toK + 1] -= B[v]
        else:
            a[0] += B[v]
            a[toK - m + 1] -= B[v]
            a[p + 1] += B[v]
        if len(l) > 1:
            loopsL, remL = min(K, len(l) - 1) // m, min(K, len(l) - 1) % m
            a[0] += loopsL * -B[v] % MOD
            toL = p + remL
            if toL < m:
                a[p + 1] += -B[v]
                a[toL + 1] -= -B[v]
            else:
                a[0] += -B[v]
                a[toL - m + 1] -= -B[v]
                a[p + 1] += -B[v]
        if len(l) > 1:
            ans[l[-1]] += B[v]
        if len(l) > K + 1:
            ans[l[-K - 1]] -= B[v]
        l.append(v)
        for w in rA[v]:
            if w not in s:
                dfs(w, pos)
                ans[v] += ans[w]
        l.pop()

    for i in range(m):
        dfs(r[i], i)

    for i in range(m):
        ans[r[i]] = a[i]
        a[i + 1] += a[i]


for i in range(N):
    if not visited[i]:
        cur = i
        l = []
        while not visited[cur]:
            visited[cur] = True
            l.append(cur)
            cur = A[cur]
        solve(l[l.index(cur):])
d = pow(K, MOD - 2, MOD)
print(' '.join(map(lambda x: str(x % MOD * d % MOD), ans)))

posted:
last update: