Official

D - チェックポイントラリー / Checkpoint Rally Editorial by admin

(非推奨) Claude 4.5 Opus

概要

スタート地点から指定された \(K\) 個のチェックポイントを順番に訪問し、ゴールに到達するまでの最短時間を求める問題です。各区間でダイクストラ法を用いて最短経路を計算します。

考察

問題の整理

この問題では、以下の順序で移動する必要があります: $\(1 \rightarrow P_1 \rightarrow P_2 \rightarrow \cdots \rightarrow P_K \rightarrow N\)$

つまり、全体の移動時間は各区間の最短距離の合計になります。

素朴なアプローチの問題点

毎回の区間移動ごとにダイクストラ法を実行すると、\(K+1\) 回のダイクストラ法が必要です。しかし、同じチェックポイントが複数回出現する場合(例:\(P_1 = P_3\) など)、同じ計算を繰り返すことになり無駄です。

効率化のポイント

重要な気づき:必要なチェックポイントは高々 \(K+2\) 個(スタート、\(K\) 個の中間点、ゴール)です。\(K \leq 10\) なので、最大でも \(12\) 個しかありません。

したがって、重複を除いた各チェックポイントからダイクストラ法を1回ずつ実行しておけば、任意の2点間の最短距離がすぐに分かります。

具体例

例えば、\(K=2\) で訪問順序が \(1 \rightarrow 3 \rightarrow 5 \rightarrow N\) の場合: - チェックポイント \(1, 3, 5, N\) それぞれからダイクストラ法を実行 - \(\text{dist}[1][3] + \text{dist}[3][5] + \text{dist}[5][N]\) を計算

アルゴリズム

  1. グラフの構築: 入力された \(M\) 本の道から隣接リストを作成
  2. チェックポイントリストの作成: \([1, P_1, P_2, \ldots, P_K, N]\) を作成
  3. ダイクストラ法の実行: リスト内の各チェックポイント(重複除去後)を始点としてダイクストラ法を実行し、全頂点への最短距離を計算
  4. 総移動時間の計算: 順番に隣接するチェックポイント間の最短距離を足し合わせる
  5. 到達不可能の判定: いずれかの区間で到達不可能(距離が \(\infty\))なら -1 を出力
checkpoints = [1, P_1, P_2, ..., P_K, N]
total_time = dist[1][P_1] + dist[P_1][P_2] + ... + dist[P_K][N]

計算量

  • 時間計算量: \(O((K+2) \cdot (M + N) \log N)\)

    • ダイクストラ法1回: \(O((M + N) \log N)\)
    • 実行回数: 最大 \(K+2\) 回(重複除去後)
    • \(K \leq 10\) なので、実質 \(O(M \log N)\) 程度
  • 空間計算量: \(O((K+2) \cdot N + M)\)

    • 各チェックポイントからの距離配列: \(O(N)\) が最大 \(K+2\)
    • グラフの隣接リスト: \(O(M)\)

実装のポイント

  1. 重複除去の活用: set(checkpoints) を使って、同じチェックポイントからのダイクストラ法を複数回実行しないようにする

  2. 距離の初期値: float('inf') を使用し、到達不可能な場合を正しく判定する

  3. 辞書による距離管理: dist_from[cp] に各チェックポイント cp からの最短距離配列を格納し、効率的にアクセスする

  4. 早期終了: いずれかの区間で到達不可能と分かった時点で -1 を出力して終了する

    ソースコード

import heapq
from collections import defaultdict

def dijkstra(graph, start, n):
    dist = [float('inf')] * (n + 1)
    dist[start] = 0
    pq = [(0, start)]
    
    while pq:
        d, u = heapq.heappop(pq)
        if d > dist[u]:
            continue
        for v, w in graph[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                heapq.heappush(pq, (dist[v], v))
    
    return dist

def main():
    import sys
    input = sys.stdin.readline
    
    N, M, K = map(int, input().split())
    
    graph = defaultdict(list)
    for _ in range(M):
        u, v, t = map(int, input().split())
        graph[u].append((v, t))
        graph[v].append((u, t))
    
    P = list(map(int, input().split()))
    
    # 訪問する必要がある頂点のリスト: 1, P_1, P_2, ..., P_K, N
    checkpoints = [1] + P + [N]
    
    # 各チェックポイントからの最短距離を計算
    dist_from = {}
    for cp in set(checkpoints):
        dist_from[cp] = dijkstra(graph, cp, N)
    
    # 経路に沿って最小時間を計算
    total_time = 0
    for i in range(len(checkpoints) - 1):
        src = checkpoints[i]
        dst = checkpoints[i + 1]
        d = dist_from[src][dst]
        if d == float('inf'):
            print(-1)
            return
        total_time += d
    
    print(total_time)

if __name__ == "__main__":
    main()

この解説は claude4.5opus によって生成されました。

posted:
last update: