G - Minimum Steiner Tree 2 Editorial by en_translator
Consider the following DP (Dynamic Programming):
- \(DP[x][y]=\) the minimum cost of a connected subgraph containing every vertex in the set \(x\) as well as vertex \(y\).
Then the answer to each query is \(DP[\lbrace 1,2,\dots,K,s_i\rbrace][t_i]\). Can we precompute all of them?
\(DP[x][y]\) can be computed as follows:
- \(DP[x][y]=0\) if \(x=\emptyset\) or \(x=\lbrace y\rbrace\).
- Otherwise, \(\displaystyle DP[x][y]=\min\left(\min_{\emptyset\neq z\subsetneq x}\left(DP[z][y]+DP[x\setminus z][y]\right),\ \min_{1\leq i\leq N}\left(DP[x][i]+C_{i,y}\right)\right)\).
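As a minimal sketch of the first argument of \(\min\), assuming the set \(x\) is stored as a bitmask, the subsets \(z\) can be enumerated with the usual `(sub - 1) & mask` trick. The names `submasks` and `merge_step` are illustrative and not part of the original solution.

```python
def submasks(mask):
    """Yield every submask of `mask`, from `mask` itself down to 0."""
    sub = mask
    while True:
        yield sub
        if sub == 0:
            break
        sub = (sub - 1) & mask


def merge_step(dp, mask, n):
    """Apply dp[mask][y] = min(dp[z][y] + dp[mask \\ z][y]) over all submasks z."""
    for sub in submasks(mask):
        rest = mask ^ sub  # the complementary part of the set
        for y in range(n):
            dp[mask][y] = min(dp[mask][y], dp[sub][y] + dp[rest][y])
```

For example, `list(submasks(0b101))` gives `[5, 4, 1, 0]`. The two ends (`mask` itself and `0`) pair with the empty set, so they leave `dp[mask]` unchanged as long as `dp[0][y]` is 0 for every `y`.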
Notice that the second argument of \(\min\) refers to \(DP[x][\cdot]\) itself, so the recurrence is cyclic. This can be circumvented by skipping the transitions that refer to the state itself, processing the first argument of \(\min\) (the subset merges) first, and then handling the second argument in the manner of Dijkstra’s algorithm. Since the graph is dense, the \(O(N^2)\) variant of Dijkstra’s algorithm, which finds the next nearest vertex by scanning all the vertices at every step, is faster than a heap-based one.
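The Dijkstra-style step can be sketched on its own as follows. This is an illustration under the assumption that all costs are non-negative and that `dist` already holds the values produced by the subset merges; the helper name `relax_dense` is hypothetical.

```python
def relax_dense(dist, cost):
    """Run O(N^2) Dijkstra in place, starting from the initial values in dist."""
    n = len(dist)
    done = [False] * n
    for _ in range(n):
        # Linear scan for the unfinished vertex with the smallest tentative value.
        v = min((u for u in range(n) if not done[u]), key=lambda u: dist[u])
        done[v] = True
        # Relax every edge leaving v; the graph is complete, so this is O(N).
        for u in range(n):
            if dist[v] + cost[v][u] < dist[u]:
                dist[u] = dist[v] + cost[v][u]
    return dist
```

Unlike ordinary single-source Dijkstra, every vertex may start with a finite value, which behaves like a virtual source connected to each vertex `y` with the initial `dist[y]` as the edge cost.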
The complexity of evaluating \(DP[x'][*]\) for all subsets \(x'\subseteq x\) is \(O(3^{|x|}N+2^{|x|}N^2)\). Here \(x\) ranges over the sets obtained by adding at most one element to \(\lbrace 1,2,\dots,K\rbrace\), so the problem is solved in a total of \(O(3^KN^2+2^KN^3)\) time. If the constant factor is small and your language is fast, an \(O(3^KN^2+2^KN^3\log N)\) solution may also pass.
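The \(3^{|x|}\) factor comes from counting (set, subset) pairs: a set with \(j\) elements has \(2^j\) subsets, and \(\sum_j \binom{k}{j}2^j=3^k\). A small sanity check of that count, with `merge_pairs` as an illustrative helper name:

```python
def merge_pairs(k):
    """Number of (mask, submask) pairs over all masks of k bits."""
    # A mask with p set bits has exactly 2**p submasks.
    return sum(1 << bin(mask).count("1") for mask in range(1 << k))
```

Each pair costs \(O(N)\) work in the merge step, which gives the \(3^{|x|}N\) term; the \(2^{|x|}N^2\) term is one \(O(N^2)\) Dijkstra pass per mask.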
INF = 1 << 60
N, K = map(int, input().split())
C = [list(map(int, input().split())) for _ in range(N)]
ans = [[0] * N for _ in range(N)]
# dp[mask][v]: minimum cost to connect the terminals in mask together with vertex v
dp = [[INF] * N for _ in range(1 << K)]
dp[0] = [0] * N
for i in range(K):
    dp[1 << i][i] = 0
for i in range(1, 1 << K):
    # First argument of min: merge two sub-solutions that share vertex j
    bi = i
    while True:
        for j in range(N):
            dp[i][j] = min(dp[i][j], dp[bi][j] + dp[i ^ bi][j])
        if bi == 0:
            break
        bi = (bi - 1) & i
    # Second argument of min: O(N^2) Dijkstra over the dense graph
    used = [False] * N
    for _ in range(N):
        mi = (INF, -1)
        for v in range(N):
            if not used[v]:
                mi = min(mi, (dp[i][v], v))
        _, v = mi
        used[v] = True
        for u in range(N):
            dp[i][u] = min(dp[i][u], dp[i][v] + C[v][u])
# For each extra vertex s (0-indexed), ndp[mask][v] is the cost for mask plus s, together with v
for s in range(K, N):
    ndp = [[INF] * N for _ in range(1 << K)]
    ndp[0][s] = 0
    for i in range(1 << K):
        # Merge: the part containing s comes from ndp, the other part from dp
        bi = i
        while True:
            for j in range(N):
                ndp[i][j] = min(ndp[i][j], ndp[bi][j] + dp[i ^ bi][j])
            if bi == 0:
                break
            bi = (bi - 1) & i
        # O(N^2) Dijkstra, as above
        used = [False] * N
        for _ in range(N):
            mi = (INF, -1)
            for v in range(N):
                if not used[v]:
                    mi = min(mi, (ndp[i][v], v))
            _, v = mi
            used[v] = True
            for u in range(N):
                ndp[i][u] = min(ndp[i][u], ndp[i][v] + C[v][u])
    # The full mask together with s: ans[s][t] answers the query (s, t)
    ans[s] = ndp[-1][:]
Q = int(input())
for _ in range(Q):
    s, t = map(int, input().split())
    print(ans[s - 1][t - 1])