Official

D - Shortest Path Queries 2 Editorial by en_translator


This problem captures the essence of the algorithm known as the Floyd–Warshall algorithm.

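For convenience, we define \(f(s,t,0)\), the shortest distance when no intermediate vertex may be used, as follows. Throughout, the length of a nonexistent path is treated as \(\infty\). Such entries contribute nothing when the answer is summed up.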

\[ f(s,t,0) = \begin{cases} 0 & \text{if } s = t, \\ \text{the weight of the edge } s \to t & \text{if } s \neq t \text{ and the edge exists}, \\ \infty & \text{otherwise}. \end{cases} \]
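The following is a minimal sketch of building the table \(f(\ast,\ast,0)\) in Python. It assumes that edges arrive as 1-indexed `(a, b, c)` triples, each meaning a directed edge \(a \to b\) of weight \(c\). The helper name `initial_matrix` and the use of `float("inf")` for the infinite case are illustrative choices. Neither is part of the original solution.

```python
INF = float("inf")  # stands in for the "edge doesn't exist" case


def initial_matrix(n, edges):
    """Return f(., ., 0) as an n x n list of lists (0-indexed).

    `edges` holds 1-indexed (a, b, c) triples: a directed edge a -> b of
    weight c. The function name and input shape are illustrative assumptions.
    """
    d = [[INF] * n for _ in range(n)]
    for a, b, c in edges:
        # Keep the lightest edge in case the same pair appears twice
        # (an assumption; the problem may guarantee distinct pairs).
        d[a - 1][b - 1] = min(d[a - 1][b - 1], c)
    for i in range(n):
        d[i][i] = 0  # the s = t case takes precedence over any edge
    return d
```

For example, `initial_matrix(3, [(1, 2, 3), (2, 3, 2)])` gives `[[0, 3, inf], [inf, 0, 2], [inf, inf, 0]]`.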

Now let us express \(f(s,t,k+1)\) in terms of \(f(\ast,\ast,k)\). Consider a shortest path from \(s\) to \(t\) that may pass only through \(s\), \(t\), and vertices with index at most \(k+1\). Such a path either passes through vertex \(k+1\) or it does not. Its length in each case is computed as follows.

  • If the path does not pass through vertex \(k+1\)
    • The only vertices available are \(s\), \(t\), and those with index at most \(k\), so the length of the shortest path is \(f(s,t,k)\).
  • If the path passes through vertex \(k+1\)
    • Since the edge weights are non-negative, a shortest path never needs to visit a vertex twice. We may therefore assume it visits \(k+1\) exactly once. The path then splits at \(k+1\) into a path from \(s\) to \(k+1\) and a path from \(k+1\) to \(t\), and apart from its endpoints each part passes only through vertices with index at most \(k\). Taking the shortest choice for each part, the length is \(f(s,k+1,k) + f(k+1,t,k)\).
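The case analysis can be checked on small inputs by computing \(f(s,t,k)\) directly from its definition and comparing it with the two cases above. The sketch below uses a hypothetical helper, `restricted_dist`. It runs Bellman–Ford-style relaxation on the subgraph of allowed vertices, assuming 1-indexed vertices and non-negative weights. It is meant only as a brute-force reference, not as part of the solution.

```python
INF = float("inf")


def restricted_dist(n, edges, s, t, k):
    """Brute-force f(s, t, k): shortest s -> t distance where every vertex
    other than s and t has index <= k. Vertices are 1-indexed; `edges` holds
    (a, b, c) triples for directed edges a -> b of weight c (assumed shape).
    """
    if s == t:
        return 0

    def allowed(v):
        return v == s or v == t or v <= k

    dist = {v: INF for v in range(1, n + 1)}
    dist[s] = 0
    # Relax only edges whose endpoints are both allowed; n - 1 rounds suffice
    # because a shortest path in the allowed subgraph has at most n - 1 edges.
    for _ in range(n - 1):
        for a, b, c in edges:
            if allowed(a) and allowed(b) and dist[a] + c < dist[b]:
                dist[b] = dist[a] + c
    return dist[t]
```

Consider the graph with edges \(1 \to 2\) (weight 3) and \(2 \to 3\) (weight 2). There, `restricted_dist(3, edges, 1, 3, 1)` is `inf` because vertex 2 is not yet allowed. With \(k = 2\) it is `5`, which equals \(f(1,2,1) + f(2,3,1)\).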

Therefore, \(f(s,t,k+1)\) and \(f(\ast,\ast,k)\) satisfy the following relation.

\[f(s,t,k+1) = \min(f(s,t,k) , f(s,k+1,k) + f(k+1,t,k))\]
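One layer of the recurrence translates almost literally into code. This sketch uses the hypothetical name `recurrence_step` and stores \(f(\ast,\ast,k)\) as a 0-indexed list of lists, so vertex \(k+1\) lives at row and column index \(k\). Unreachable entries are assumed to be `float("inf")`.

```python
INF = float("inf")


def recurrence_step(d, k):
    """Given d = f(., ., k) as a 0-indexed matrix, return f(., ., k + 1).

    Vertex k + 1 (1-indexed, as in the editorial) is index k here.
    """
    n = len(d)
    m = k  # 0-indexed position of vertex k + 1
    return [
        # either avoid vertex k + 1, or go s -> (k + 1) -> t
        [min(d[s][t], d[s][m] + d[m][t]) for t in range(n)]
        for s in range(n)
    ]
```

Starting from \(f(\ast,\ast,0)\) and calling this for \(k = 0, 1, \ldots, N-1\) produces \(f(\ast,\ast,N)\), one layer per call.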

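Using DP (Dynamic Programming) based on this recurrence, we compute \(f(s,t,k)\) for each of the \(N\) values of \(k\) and each of the \(N^2\) pairs \((s,t)\) in \(\mathrm{O}(1)\) time. The problem is therefore solved in a total of \(\mathrm{O}(N^3)\) time.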

Note that \(f(s,t,N)\) is the length of the shortest path from \(s\) to \(t\) when every vertex may be passed through. Thus the shortest distances for all pairs \((s,t)\) can be computed in \(\mathrm{O}(N^3)\) time. This algorithm is called the Floyd–Warshall algorithm.
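When only \(f(\ast,\ast,N)\) is needed, the usual implementation overwrites a single matrix instead of allocating a new one for each \(k\). Assuming non-negative weights, this is safe because \(d[k][k] = 0\), so row \(k\) and column \(k\) never change during iteration \(k\). Below is a minimal sketch; the function name `floyd_warshall` and the 1-indexed edge-triple input are illustrative assumptions.

```python
INF = float("inf")


def floyd_warshall(n, edges):
    """All-pairs shortest distances f(s, t, N) for a directed graph with
    non-negative weights. `edges` holds 1-indexed (a, b, c) triples; the
    returned matrix is 0-indexed and unreachable pairs stay INF.
    """
    d = [[INF] * n for _ in range(n)]
    for i in range(n):
        d[i][i] = 0
    for a, b, c in edges:
        if a != b:
            d[a - 1][b - 1] = min(d[a - 1][b - 1], c)
    for k in range(n):
        dk = d[k]  # row k is unchanged in this iteration because d[k][k] == 0
        for i in range(n):
            dik = d[i][k]  # column k is unchanged as well
            if dik == INF:
                continue  # no path i -> k yet, nothing to relax
            di = d[i]
            for j in range(n):
                if dik + dk[j] < di[j]:
                    di[j] = dik + dk[j]
    return d
```

After the outer iteration for index \(k\), the matrix equals \(f(\ast,\ast,k+1)\) exactly. The same in-place loop could therefore also accumulate a per-\(k\) sum. A separate table for each layer is correct but not required.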

Here is sample code in Python.

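```python
import sys

N, M = map(int, sys.stdin.buffer.readline().split())
ABC = map(int, sys.stdin.buffer.read().split())
# d[i][j] = f(i+1, j+1, 0); 1 << 60 plays the role of inf
d = [[1 << 60] * N for i in range(N)]
for i in range(N):
    d[i][i] = 0
# zipping one iterator with itself reads the edges as (A, B, C) triples
for a, b, c in zip(ABC, ABC, ABC):
    d[a - 1][b - 1] = c
answer = 0
for k in range(N):
    # nxt becomes f(., ., k+1): vertex k (0-indexed) is vertex k+1 (1-indexed)
    nxt = [[0] * N for i in range(N)]
    for i in range(N):
        for j in range(N):
            nxt[i][j] = min(d[i][j], d[i][k] + d[k][j])
            # unreachable pairs (still about inf) contribute nothing
            if nxt[i][j] < 1 << 59:
                answer += nxt[i][j]
    d = nxt
print(answer)
```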
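To try the sample code without piping standard input, one can wrap the same logic in a function that takes the input as a string. The name `solve` and the string interface are assumptions for illustration. The input is read in the same order as the sample code reads it: first `N M`, then `M` triples `A B C`.

```python
def solve(text):
    """Sum of f(s, t, k) over all s, t, k, skipping unreachable pairs.

    Hypothetical wrapper: same steps as the sample code, but the input
    ("N M" followed by M triples "A B C") is passed in as a string.
    """
    it = iter(map(int, text.split()))
    n, m = next(it), next(it)
    inf = 1 << 60  # same sentinel as the sample code
    d = [[inf] * n for _ in range(n)]
    for i in range(n):
        d[i][i] = 0
    for _ in range(m):
        a, b, c = next(it), next(it), next(it)
        d[a - 1][b - 1] = c
    answer = 0
    for k in range(n):
        # nxt holds f(., ., k + 1) in 0-indexed form
        nxt = [[min(d[i][j], d[i][k] + d[k][j]) for j in range(n)]
               for i in range(n)]
        answer += sum(x for row in nxt for x in row if x < 1 << 59)
        d = nxt
    return answer
```

Take the tiny graph with edges \(1 \to 2\) (weight 3) and \(2 \to 3\) (weight 2). Its per-\(k\) sums are 5, 10, and 10, so `solve("3 2\n1 2 3\n2 3 2\n")` returns 25.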
