D - チェックポイントラリー / Checkpoint Rally 解説 by admin
Qwen3-Coder-480B概要
与えられたグラフ上で、指定されたチェックポイントを順番に訪問した後、ゴールまで移動する最短時間を求める問題です。
考察
この問題は、単純に全ての頂点間の最短距離を求めたあと、指定された順序で移動するだけのように見えますが、頂点数 \(N\) が最大 \(5 \times 10^4\) と非常に大きいため、全点対最短距離(例えばワーシャル–フロイド法)を使うことはできません。
しかし、訪問するチェックポイント数 \(K\) は最大でも \(10\) と非常に小さいことに注目します。つまり、実際に最短距離が必要となるのは、スタート地点(頂点 \(1\))、チェックポイント(最大 \(10\) 個)、ゴール地点(頂点 \(N\))の合計高々 \(12\) 個の頂点間だけです。
したがって、これらの「必要な頂点」についてのみ、他の全頂点からの最短距離を前計算すれば十分です。これにより、無駄な計算を省き、効率的に解くことが可能です。
また、各頂点からの最短距離を求めるには、ダイクストラ法が適しており、 priority queue を用いることで高速に計算できます。
アルゴリズム
- グラフを隣接リストとして構築します。
- 必要な頂点のリストを作ります:これは
[1] + P[1..K] + [N]です。 - このリスト内の各頂点 \(v\) に対して、ダイクストラ法を用いて全頂点への最短距離
dist[v][*]を求めます。 - 順番に訪問する頂点間の最短距離を合計し、全体の最短時間を求めます。
- いずれかの区間で到達不能(距離が \(\infty\))であれば
-1を出力します。
例
例えば、入力が以下の場合:
4 4 2
1 2 10
2 3 10
3 4 10
1 3 30
2 3
- 訪問順は:1 → 2 → 3 → 4
- 最短距離:
- 1→2: 10
- 2→3: 10
- 3→4: 10
- 合計: 30
計算量
- 時間計算量: \(O(K \cdot (M \log N))\)
- 必要な頂点数は最大 \(K+2\) 個で、それぞれに対しダイクストラ法を実行。
- 空間計算量: \(O(N + M + K)\)
- グラフの隣接リスト、距離配列、キューなどで使用されるメモリ。
実装のポイント
入力を高速に読み込むために
sys.stdin.readを使用しています。不要な頂点の最短距離計算を避けることで、計算量を大幅に削減しています。
ダイクストラ法の実装において、priority queue(
heapq)の使い方に注意しましょう。ソースコード
import heapq
from collections import defaultdict
import sys
def dijkstra(graph, start, n):
dist = [float('inf')] * (n + 1)
dist[start] = 0
pq = [(0, start)]
while pq:
d, u = heapq.heappop(pq)
if d > dist[u]:
continue
for v, w in graph[u]:
if dist[u] + w < dist[v]:
dist[v] = dist[u] + w
heapq.heappush(pq, (dist[v], v))
return dist
def main():
import sys
input = sys.stdin.read
data = input().split()
idx = 0
N = int(data[idx]); idx += 1
M = int(data[idx]); idx += 1
K = int(data[idx]); idx += 1
graph = defaultdict(list)
for _ in range(M):
u = int(data[idx]); idx += 1
v = int(data[idx]); idx += 1
t = int(data[idx]); idx += 1
graph[u].append((v, t))
graph[v].append((u, t))
if K == 0:
checkpoints = []
else:
checkpoints = list(map(int, data[idx:idx+K]))
idx += K
# チェックポイントリスト(1 → P1 → P2 → ... → PK → N)
points = [1] + checkpoints + [N]
L = len(points)
# 各ノードから全点対最短距離を前計算(必要なノードのみ)
required_nodes = list(set(points))
dist_dict = {}
for node in required_nodes:
dist_dict[node] = dijkstra(graph, node, N)
# 経路の総コストを計算
total_time = 0
for i in range(L - 1):
s = points[i]
t = points[i+1]
d = dist_dict[s][t]
if d == float('inf'):
print(-1)
return
total_time += d
print(total_time)
if __name__ == "__main__":
main()
この解説は qwen3-coder-480b によって生成されました。
投稿日時:
最終更新: