Official

D - チェックポイントラリー / Checkpoint Rally Editorial by admin

Qwen3-Coder-480B

概要

与えられたグラフ上で、指定されたチェックポイントを順番に訪問した後、ゴールまで移動する最短時間を求める問題です。

考察

この問題は、単純に全ての頂点間の最短距離を求めたあと、指定された順序で移動するだけのように見えますが、頂点数 \(N\) が最大 \(5 \times 10^4\) と非常に大きいため、全点対最短距離(例えばワーシャル–フロイド法)を使うことはできません。

しかし、訪問するチェックポイント数 \(K\) は最大でも \(10\) と非常に小さいことに注目します。つまり、実際に最短距離が必要となるのは、スタート地点(頂点 \(1\))、チェックポイント(最大 \(10\) 個)、ゴール地点(頂点 \(N\))の合計高々 \(12\) 個の頂点間だけです。

したがって、これらの「必要な頂点」についてのみ、他の全頂点からの最短距離を前計算すれば十分です。これにより、無駄な計算を省き、効率的に解くことが可能です。

また、各頂点からの最短距離を求めるには、ダイクストラ法が適しており、 priority queue を用いることで高速に計算できます。

アルゴリズム

  1. グラフを隣接リストとして構築します。
  2. 必要な頂点のリストを作ります:これは [1] + P[1..K] + [N] です。
  3. このリスト内の各頂点 \(v\) に対して、ダイクストラ法を用いて全頂点への最短距離 dist[v][*] を求めます。
  4. 順番に訪問する頂点間の最短距離を合計し、全体の最短時間を求めます。
  5. いずれかの区間で到達不能(距離が \(\infty\))であれば -1 を出力します。

例えば、入力が以下の場合:

4 4 2
1 2 10
2 3 10
3 4 10
1 3 30
2 3
  • 訪問順は:1 → 2 → 3 → 4
  • 最短距離:
    • 1→2: 10
    • 2→3: 10
    • 3→4: 10
  • 合計: 30

計算量

  • 時間計算量: \(O(K \cdot (M \log N))\)
    • 必要な頂点数は最大 \(K+2\) 個で、それぞれに対しダイクストラ法を実行。
  • 空間計算量: \(O(N + M + K)\)
    • グラフの隣接リスト、距離配列、キューなどで使用されるメモリ。

実装のポイント

  • 入力を高速に読み込むために sys.stdin.read を使用しています。

  • 不要な頂点の最短距離計算を避けることで、計算量を大幅に削減しています。

  • ダイクストラ法の実装において、priority queue(heapq)の使い方に注意しましょう。

    ソースコード

import heapq
from collections import defaultdict
import sys

def dijkstra(graph, start, n):
    dist = [float('inf')] * (n + 1)
    dist[start] = 0
    pq = [(0, start)]
    while pq:
        d, u = heapq.heappop(pq)
        if d > dist[u]:
            continue
        for v, w in graph[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                heapq.heappush(pq, (dist[v], v))
    return dist

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    idx = 0
    N = int(data[idx]); idx += 1
    M = int(data[idx]); idx += 1
    K = int(data[idx]); idx += 1
    
    graph = defaultdict(list)
    
    for _ in range(M):
        u = int(data[idx]); idx += 1
        v = int(data[idx]); idx += 1
        t = int(data[idx]); idx += 1
        graph[u].append((v, t))
        graph[v].append((u, t))
    
    if K == 0:
        checkpoints = []
    else:
        checkpoints = list(map(int, data[idx:idx+K]))
        idx += K
    
    # チェックポイントリスト(1 → P1 → P2 → ... → PK → N)
    points = [1] + checkpoints + [N]
    L = len(points)
    
    # 各ノードから全点対最短距離を前計算(必要なノードのみ)
    required_nodes = list(set(points))
    dist_dict = {}
    
    for node in required_nodes:
        dist_dict[node] = dijkstra(graph, node, N)
    
    # 経路の総コストを計算
    total_time = 0
    for i in range(L - 1):
        s = points[i]
        t = points[i+1]
        d = dist_dict[s][t]
        if d == float('inf'):
            print(-1)
            return
        total_time += d
    
    print(total_time)

if __name__ == "__main__":
    main()

この解説は qwen3-coder-480b によって生成されました。

posted:
last update: