Official

G - Minimum Steiner Tree 2 Editorial by en_translator


Consider the following DP (Dynamic Programming):

  • \(DP[x][y]=\) the minimum total cost of a connected subgraph that contains every vertex in the set \(x\) as well as vertex \(y\).

Then the answer to the \(i\)-th query is \(DP[\lbrace 1,2,\dots,K,s_i\rbrace][t_i]\). Can we precompute all of these values?
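To make the definition concrete, here is a brute-force reference for tiny inputs, which is handy for cross-checking a faster implementation. It is a minimal sketch, not part of the original solution. It assumes a complete graph on 0-indexed vertices with symmetric, non-negative costs `C[i][j]`. Under these assumptions, an optimal connected subgraph can be taken to be a minimum spanning tree of some vertex set that contains all required vertices. The helper names `mst_cost` and `steiner_brute` are hypothetical.

```python
from itertools import combinations

# Hypothetical helpers for testing only; not from the editorial's sample code.
INF = float("inf")


def mst_cost(vertices, C):
    """Prim's algorithm on the complete subgraph induced by `vertices`."""
    vertices = list(vertices)
    if len(vertices) <= 1:
        return 0
    # best[v] = cheapest edge from the current tree to the outside vertex v
    best = {v: C[vertices[0]][v] for v in vertices[1:]}
    total = 0
    while best:
        v = min(best, key=best.get)
        total += best.pop(v)
        for u in best:
            best[u] = min(best[u], C[v][u])
    return total


def steiner_brute(required, C):
    """Minimum cost of a connected subgraph containing every vertex in `required`.

    Assumes a complete graph with symmetric, non-negative costs C[i][j], and
    tries every set of extra (Steiner) vertices. Exponential in N.
    """
    required = set(required)
    others = [v for v in range(len(C)) if v not in required]
    best = INF
    for r in range(len(others) + 1):
        for extra in combinations(others, r):
            best = min(best, mst_cost(required | set(extra), C))
    return best
```

With terminals `0, ..., K-1` in 0-indexed form, \(DP[\lbrace 1,\dots,K,s\rbrace][t]\) corresponds to `steiner_brute(set(range(K)) | {s - 1, t - 1}, C)`.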

\(DP[x][y]\) can be computed as follows:

  • \(DP[x][y]=0\) if \(x=\emptyset\) or \(x=\lbrace y\rbrace\).
  • Otherwise, \(\displaystyle DP[x][y]=\min\left(\min_{\emptyset\neq z\subsetneq x}\bigl(DP[z][y]+DP[x\setminus z][y]\bigr),\ \min_{1\leq i\leq N}\bigl(DP[x][i]+C_{i,y}\bigr)\right)\).
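The first argument of \(\min\) is commonly implemented with the submask-enumeration trick \(z\leftarrow(z-1)\mathbin{\&}x\), which the sample code also uses. The following is a minimal sketch of that transition on its own. It assumes `dp[mask]` is a list indexed by vertex and that every proper subset of `x` is already final. `submasks` and `merge_transition` are hypothetical helper names.

```python
# Hypothetical helpers illustrating the first transition; not from the sample code.

def submasks(x):
    """Yield every subset z of the bitmask x, from x itself down to 0."""
    z = x
    while True:
        yield z
        if z == 0:
            return
        z = (z - 1) & x  # next smaller submask of x


def merge_transition(dp, x):
    """dp[x][y] <- min(dp[x][y], dp[z][y] + dp[x ^ z][y]) over non-empty proper submasks z.

    Since z is a submask of x, x ^ z is exactly the set difference x minus z.
    Assumes dp[z] is already final for every proper subset z of x.
    """
    for z in submasks(x):
        if z == 0 or z == x:
            continue  # these would only re-read dp[x] itself
        for y in range(len(dp[x])):
            dp[x][y] = min(dp[x][y], dp[z][y] + dp[x ^ z][y])
```

Each unordered split \(\lbrace z, x\setminus z\rbrace\) is visited twice. That is harmless, and it keeps the loop simple.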

Notice that the second transition refers to \(DP[x][*]\) itself. We can circumvent this in two stages. First, apply the transitions in the first argument of \(\min\), which only refer to proper subsets of \(x\). Then handle the second-argument transitions in the manner of Dijkstra's algorithm, with every vertex \(i\) acting as a source whose initial distance is the current value of \(DP[x][i]\). Since the graph is dense, the \(O(N^2)\) version of Dijkstra's algorithm, which finds the next nearest vertex by scanning all vertices at every step, is faster than the \(O(N^2\log N)\) version that uses a priority queue.
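Here is a minimal sketch of the \(O(N^2)\) Dijkstra step. It assumes a complete graph with non-negative costs `C[i][j]`. Every vertex `i` starts with distance `dist[i]` (the values \(DP[x][i]\) after the subset transitions), so all vertices act as sources at once. The function name `dense_dijkstra` is hypothetical.

```python
# Hypothetical helper; not from the editorial's sample code.

def dense_dijkstra(dist, C):
    """In place: dist[y] <- min over i of (dist[i] + shortest-path cost from i to y).

    O(N^2) Dijkstra for a complete graph with non-negative costs C[i][j]:
    each step scans all unfinished vertices for the nearest one (no heap).
    """
    n = len(dist)
    done = [False] * n
    for _ in range(n):
        # Find the unfinished vertex with the smallest tentative distance
        v = -1
        for u in range(n):
            if not done[u] and (v == -1 or dist[u] < dist[v]):
                v = u
        done[v] = True
        # Relax every edge leaving v
        for u in range(n):
            if dist[v] + C[v][u] < dist[u]:
                dist[u] = dist[v] + C[v][u]
    return dist
```

Calling `dense_dijkstra(dp[x], C)` after the subset transitions for `x` finishes \(DP[x][*]\).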

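For a fixed set \(x\), evaluating \(DP[z][*]\) for all subsets \(z\subseteq x\) takes \(O(3^{|x|}N+2^{|x|}N^2)\) time. The subset transitions enumerate every pair \(z'\subseteq z\subseteq x\), and there is one \(O(N^2)\) Dijkstra run per subset. Here \(x\) ranges over the \(N-K+1\) sets obtained by adding at most one element to \(\lbrace 1,2,\dots,K\rbrace\), so the problem is solved in a total of \(O(3^KN^2+2^KN^3)\) time. If the constant factor is small and your language is fast, an \(O(3^KN^2+2^KN^3\log N)\) solution (for example, one that uses Dijkstra's algorithm with a binary heap) may also pass.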
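For comparison, here is a minimal sketch of the heap-based variant behind the \(O(\dots\log N)\) bound. It uses Python's standard `heapq` module and performs the same multi-source relaxation as the \(O(N^2)\) scan-based Dijkstra described above. It assumes a complete graph with non-negative costs `C[i][j]`. `heap_dijkstra` is a hypothetical name.

```python
import heapq

# Hypothetical helper; not from the editorial's sample code.

def heap_dijkstra(dist, C):
    """In place: same multi-source relaxation as the scan-based version, with a binary heap.

    On a complete graph each vertex is settled once and scans N edges, so up to
    N^2 pushes occur and one run costs O(N^2 log N) instead of O(N^2).
    """
    n = len(dist)
    heap = [(d, v) for v, d in enumerate(dist)]  # every vertex is a source
    heapq.heapify(heap)
    while heap:
        d, v = heapq.heappop(heap)
        if d > dist[v]:
            continue  # stale entry: v was already improved
        for u in range(n):
            nd = d + C[v][u]
            if nd < dist[u]:
                dist[u] = nd
                heapq.heappush(heap, (nd, u))
    return dist
```

The result is identical to the scan-based version, so either can serve as the second transition. Only the running time differs.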

Sample code (PyPy3)

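```python
import sys

input = sys.stdin.readline  # faster input reading

INF = 1 << 60

N, K = map(int, input().split())
C = [list(map(int, input().split())) for _ in range(N)]

# Vertices are 0-indexed; the terminals 1..K become bits 0..K-1.
# ans[s][t] = DP[{1, ..., K, s}][t]
ans = [[0] * N for _ in range(N)]

# dp[i][j] = DP[x][j], where x is the set of terminals encoded by bitmask i
dp = [[INF] * N for _ in range(1 << K)]
dp[0] = [0] * N

for i in range(K):
    dp[1 << i][i] = 0

for i in range(1, 1 << K):
    # First argument of min: join two subtrees at vertex j, over all submasks bi.
    # bi = i and bi = 0 only add dp[0][j] = 0, so they leave dp[i][j] unchanged.
    bi = i
    while True:
        for j in range(N):
            dp[i][j] = min(dp[i][j], dp[bi][j] + dp[i ^ bi][j])
        if bi == 0:
            break
        bi = (bi - 1) & i

    # Second argument of min: O(N^2) Dijkstra seeded with the current dp[i]
    used = [False] * N
    for _ in range(N):
        # Scan for the nearest unfinished vertex instead of using a heap
        mi = (INF, -1)
        for v in range(N):
            if not used[v]:
                mi = min(mi, (dp[i][v], v))
        _, v = mi
        used[v] = True
        for u in range(N):
            dp[i][u] = min(dp[i][u], dp[i][v] + C[v][u])

# If s is already a terminal, then {1, ..., K, s} = {1, ..., K}
for s in range(K):
    ans[s] = dp[-1][:]

for s in range(K, N):
    # ndp[i][j] = DP[x ∪ {s}][j], where x is the terminal set encoded by i
    ndp = [[INF] * N for _ in range(1 << K)]
    ndp[0][s] = 0

    for i in range(1 << K):
        # Join a subtree that contains s (ndp) with one that does not (dp).
        # bi = 0 attaches a shortest path from s to the full subtree dp[i].
        bi = i
        while True:
            for j in range(N):
                ndp[i][j] = min(ndp[i][j], ndp[bi][j] + dp[i ^ bi][j])
            if bi == 0:
                break
            bi = (bi - 1) & i

        # Dijkstra step, exactly as above
        used = [False] * N
        for _ in range(N):
            mi = (INF, -1)
            for v in range(N):
                if not used[v]:
                    mi = min(mi, (ndp[i][v], v))
            _, v = mi
            used[v] = True
            for u in range(N):
                ndp[i][u] = min(ndp[i][u], ndp[i][v] + C[v][u])

    ans[s] = ndp[-1][:]

Q = int(input())
for _ in range(Q):
    s, t = map(int, input().split())
    print(ans[s - 1][t - 1])
```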
