公式

D - 配達ルートの最短時間 / Shortest Time for Delivery Route 解説 by admin

Claude 4.6 Opus (Thinking)

概要

地点 \(1\) から中継拠点 \(K\) を経由して地点 \(N\) へ至る最短経路を求める問題です。Dijkstra法を2回実行することで効率的に解くことができます。

考察

重要な気づき:経由地点がある最短経路の分解

「地点 \(1\) → 地点 \(K\) → 地点 \(N\)」という経由付き最短経路は、次の2つの部分に分解できます。

  1. 地点 \(1\) から地点 \(K\) への最短距離 \(d(1, K)\)
  2. 地点 \(K\) から地点 \(N\) への最短距離 \(d(K, N)\)

答えは \(d(1, K) + d(K, N)\) です。

これは、経由地点 \(K\) を必ず通るという制約があるため、\(K\) で経路を「切る」ことができるからです。\(1 \to K\) の最短経路と \(K \to N\) の最短経路をそれぞれ独立に求めても、全体の最短時間は変わりません。

素朴なアプローチの問題点

もし全頂点対間の最短距離を求めようとすると、Floyd-Warshall法では \(O(N^3)\) かかり、\(N\) が最大 \(2 \times 10^5\) の本問では到底間に合いません。

解決策

必要なのは「地点 \(1\) を始点とした最短距離」と「地点 \(K\) を始点とした最短距離」の2つだけです。Dijkstra法を2回実行すれば十分です。

  • 1回目:始点 \(1\) からの最短距離配列 → \(d(1, K)\) が得られる
  • 2回目:始点 \(K\) からの最短距離配列 → \(d(K, N)\) が得られる

アルゴリズム

  1. グラフを隣接リストで構築する。
  2. Dijkstra法(1回目): 地点 \(1\) を始点として、全頂点への最短距離を求める。ここから \(d(1, K)\) を取り出す。
  3. Dijkstra法(2回目): 地点 \(K\) を始点として、全頂点への最短距離を求める。ここから \(d(K, N)\) を取り出す。
  4. \(d(1, K)\) または \(d(K, N)\)\(\infty\)(到達不可能)であれば -1 を出力する。そうでなければ \(d(1, K) + d(K, N)\) を出力する。

具体例

例えば \(N=4, K=3\) で以下の道路があるとします:

  • \(1 \leftrightarrow 2\)(コスト \(2\)
  • \(2 \leftrightarrow 3\)(コスト \(3\)
  • \(3 \leftrightarrow 4\)(コスト \(1\)

このとき: - \(d(1, 3) = 2 + 3 = 5\)\(1 \to 2 \to 3\)) - \(d(3, 4) = 1\)\(3 \to 4\)) - 答え:\(5 + 1 = 6\)

計算量

  • 時間計算量: \(O((N + M) \log N)\)
    • Dijkstra法1回あたり \(O((N + M) \log N)\) で、これを2回実行するため。
  • 空間計算量: \(O(N + M)\)
    • グラフの隣接リストに \(O(N + M)\)、距離配列に \(O(N)\)

実装のポイント

  • 高速な入力: \(N, M\) が最大 \(2 \times 10^5\) と大きいため、sys.stdin.buffer.read() でまとめて読み込み、split して処理することでPythonでも十分高速に動作させています。

  • 距離の初期値: float('inf') を使うことで、到達不可能な頂点を簡潔に判定できます。

  • オーバーフローへの注意: 辺のコストが最大 \(10^9\)、経路長は最大で \(2 \times 10^5 \times 10^9 = 2 \times 10^{14}\) 程度になりえますが、Pythonでは整数のオーバーフローがないため心配不要です。

  • d > dist[u] による枝刈り: 優先度付きキューから取り出した距離が既に更新済みの距離より大きい場合はスキップすることで、不要な探索を避けています。

    ソースコード

import heapq
import sys

def dijkstra(graph, start, n):
    INF = float('inf')
    dist = [INF] * (n + 1)
    dist[start] = 0
    pq = [(0, start)]
    while pq:
        d, u = heapq.heappop(pq)
        if d > dist[u]:
            continue
        for v, w in graph[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                heapq.heappush(pq, (dist[v], v))
    return dist

def main():
    input_data = sys.stdin.buffer.read().split()
    idx = 0
    N = int(input_data[idx]); idx += 1
    M = int(input_data[idx]); idx += 1
    K = int(input_data[idx]); idx += 1
    
    graph = [[] for _ in range(N + 1)]
    for _ in range(M):
        u = int(input_data[idx]); idx += 1
        v = int(input_data[idx]); idx += 1
        c = int(input_data[idx]); idx += 1
        graph[u].append((v, c))
        graph[v].append((u, c))
    
    dist_from_1 = dijkstra(graph, 1, N)
    dist_from_k = dijkstra(graph, K, N)
    
    INF = float('inf')
    d1k = dist_from_1[K]
    dkn = dist_from_k[N]
    
    if d1k == INF or dkn == INF:
        print(-1)
    else:
        print(d1k + dkn)

main()

この解説は claude4.6opus-thinking によって生成されました。

投稿日時:
最終更新: