Official

G - Minimum Steiner Tree 2 Editorial by hirayuu_At


以下のようなDPをします。

  • \(DP[x][y]=\) 集合 \(x\) に含まれる頂点と頂点 \(y\) を含む良いグラフの最小コスト

すると、各クエリで求めたいものは \(DP[\lbrace 1,2,\dots,K,s_i\rbrace][t_i]\) と表すことができます。これらすべてを前計算することを考えます。

\(DP[x][y]\) は、以下のように計算できます。

  • \(x=\emptyset\) または \(x=\lbrace y\rbrace\) のとき、\(DP[x][y]=0\)
  • \(\displaystyle DP[x][y]=\min(\min_{z\subset x}DP[z][y]+DP[x\setminus z][y],\min_{1\leq i\leq N}DP[x][i]+C_{i,y})\)

後者の遷移は循環していることに注意が必要です。自身を参照する遷移を行わないこと、\(x\) を固定して \(\min\) の左側の遷移を先に行い、右側の遷移はダイクストラ法の要領で行うことで回避できます。グラフが密なので逐一最小値の頂点を探す \(O(N^2)\) のダイクストラ法の方が高速です。

\(DP[x][*]\) をすべて計算する計算量は \(O(3^{|x|}N+2^{|x|}N^2)\) です。\(x\) としては \(\lbrace 1,2,\dots,K\rbrace\) に高々 \(1\) 要素追加したもののみを考えればよいので、 全体として \(O(3^KN^2+2^KN^3)\) でこの問題を解くことができました。高速な言語で定数倍が悪くなければ \(O(3^KN^2+2^KN^3\log N)\) なども通ると思います。

実装例 (PyPy3)

INF = 1 << 60

N, K = map(int, input().split())
C = [list(map(int, input().split())) for _ in range(N)]

ans = [[0] * N for _ in range(N)]
dp = [[INF] * N for _ in range(1 << K)]
dp[0] = [0] * N

for i in range(K):
    dp[1 << i][i] = 0

for i in range(1, 1 << K):
    bi = i
    while True:
        for j in range(N):
            dp[i][j] = min(dp[i][j], dp[bi][j] + dp[i ^ bi][j])
        if bi == 0:
            break
        bi = (bi - 1) & i

    used = [False] * N
    for _ in range(N):
        mi = (INF, -1)
        for v in range(N):
            if not used[v]:
                mi = min(mi, (dp[i][v], v))
        _, v = mi
        used[v] = True
        for u in range(N):
            dp[i][u] = min(dp[i][u], dp[i][v] + C[v][u])

for s in range(K, N):
    ndp = [[INF] * N for _ in range(1 << K)]
    ndp[0][s] = 0

    for i in range(1 << K):
        bi = i
        while True:
            for j in range(N):
                ndp[i][j] = min(ndp[i][j], ndp[bi][j] + dp[i ^ bi][j])
            if bi == 0:
                break
            bi = (bi - 1) & i

        used = [False] * N
        for _ in range(N):
            mi = (INF, -1)
            for v in range(N):
                if not used[v]:
                    mi = min(mi, (ndp[i][v], v))
            _, v = mi
            used[v] = True
            for u in range(N):
                ndp[i][u] = min(ndp[i][u], ndp[i][v] + C[v][u])

    ans[s] = ndp[-1][:]

Q = int(input())
for _ in range(Q):
    s, t = map(int, input().split())
    print(ans[s - 1][t - 1])

posted:
last update: