Official

E - 配達ドライバーの巡回 / Delivery Driver's Route Editorial by admin

Claude 4.6 Opus (Thinking)

概要

\(1\) から出発してすべての町を少なくとも \(1\) 回訪問し、町 \(1\) に戻る最短時間を求める問題です。これは巡回セールスマン問題 (TSP) の一種であり、\(N \leq 15\) という制約からビットマスク DP で解けます。

考察

重要な気づき

  1. 同じ道路・町を何度でも通れるため、町 \(u\) から町 \(v\) へ移動するコストは「\(u\) から \(v\) への最短距離」と考えてよいです。途中で他の町を経由しても構いません。つまり、まず全頂点対間の最短距離を求めておけば、任意の2町間を最短距離で直接移動できる完全グラフ上の問題に帰着できます。

  2. 帰着後の問題は「完全グラフ上で町 \(1\) を出発し、すべての町を \(1\) 回以上訪問して町 \(1\) に戻る最短経路」です。これはまさに巡回セールスマン問題です。

素朴なアプローチの問題

全ての訪問順序を試す場合、順列の数は \((N-1)!\) 通りです。\(N = 15\) のとき \((14)! \approx 8.7 \times 10^{10}\) となり、とても間に合いません。

解決策

ビットマスク DP を使います。「どの町の集合を訪問済みか」をビットマスクで管理することで、状態数を \(2^N \times N\) に抑えられます。\(N = 15\) のとき \(2^{15} \times 15 = 491520\) 程度であり、十分高速です。

アルゴリズム

ステップ 1: 全頂点対間最短距離(Floyd-Warshall法)

隣接行列 \(\text{dist}[i][j]\) を初期化し、Floyd-Warshall法で全ての町のペア間の最短距離を求めます。

\[\text{dist}[i][j] = \min(\text{dist}[i][j],\ \text{dist}[i][k] + \text{dist}[k][j])\]

を全ての中継点 \(k\) について更新します。

この段階で町 \(0\)(= 町1)から到達できない町があれば、答えは -1 です。

ステップ 2: ビットマスク DP(TSP)

  • 状態: \(\text{dp}[S][u]\) = 訪問済みの町の集合が \(S\)(ビットマスク)で、現在町 \(u\) にいるときの最小移動時間
  • 初期状態: \(\text{dp}[\{0\}][0] = 0\)(町 \(0\) のみ訪問済み、町 \(0\) にいる、コスト \(0\)
  • 遷移: 町 \(u\) から町 \(v\) に移動する

\[\text{dp}[S \cup \{v\}][v] = \min(\text{dp}[S \cup \{v\}][v],\ \text{dp}[S][u] + \text{dist}[u][v])\]

  • 答え: すべての町を訪問した状態 \(S = \text{full\_mask}\) に対して

\[\text{ans} = \min_{u} (\text{dp}[\text{full\_mask}][u] + \text{dist}[u][0])\]

つまり、全町訪問後に町 \(0\) に戻るコストを加えた最小値です。

計算量

  • 時間計算量: \(O(N^3 + 2^N \cdot N^2)\)
    • Floyd-Warshall: \(O(N^3)\)
    • ビットマスク DP: 状態数 \(O(2^N \cdot N)\) に対し各状態から \(N\) 通りの遷移 → \(O(2^N \cdot N^2)\)
    • \(N = 15\) のとき \(2^{15} \times 15^2 \approx 7.4 \times 10^6\) で十分高速
  • 空間計算量: \(O(2^N \cdot N)\)

実装のポイント

  • 町番号は入力で \(1\)-indexed だが、内部では \(0\)-indexed に変換すると扱いやすい

  • Floyd-Warshall の前に dist[i][i] = 0 を忘れずに初期化する

  • DP の遷移で S & (1 << u) を確認し、現在地 \(u\) が訪問済み集合に含まれているかチェックする(不正な状態からの遷移を防ぐ)

  • 最終的な答えでは、全町訪問後の現在地から町 \(0\) への帰還コストを忘れずに加算する

    ソースコード

import heapq
import sys

def solve():
    input_data = sys.stdin.read().split()
    idx = 0
    N = int(input_data[idx]); idx += 1
    M = int(input_data[idx]); idx += 1
    
    INF = float('inf')
    
    # Floyd-Warshall for all-pairs shortest paths
    dist = [[INF] * N for _ in range(N)]
    for i in range(N):
        dist[i][i] = 0
    
    for _ in range(M):
        u = int(input_data[idx]) - 1; idx += 1
        v = int(input_data[idx]) - 1; idx += 1
        w = int(input_data[idx]); idx += 1
        if w < dist[u][v]:
            dist[u][v] = w
            dist[v][u] = w
    
    for k in range(N):
        for i in range(N):
            if dist[i][k] == INF:
                continue
            for j in range(N):
                if dist[k][j] == INF:
                    continue
                if dist[i][k] + dist[k][j] < dist[i][j]:
                    dist[i][j] = dist[i][k] + dist[k][j]
    
    # Check if all towns are reachable from town 0
    for i in range(N):
        if dist[0][i] == INF:
            print(-1)
            return
    
    # TSP with bitmask DP
    # dp[S][i] = minimum cost to have visited the set of towns S, currently at town i
    # S is a bitmask of visited towns
    # We start at town 0, need to visit all towns, and return to town 0
    
    full_mask = (1 << N) - 1
    dp = [[INF] * N for _ in range(1 << N)]
    dp[1 << 0][0] = 0
    
    for S in range(1 << N):
        for u in range(N):
            if dp[S][u] == INF:
                continue
            if not (S & (1 << u)):
                continue
            for v in range(N):
                if dist[u][v] == INF:
                    continue
                nS = S | (1 << v)
                cost = dp[S][u] + dist[u][v]
                if cost < dp[nS][v]:
                    dp[nS][v] = cost
    
    ans = INF
    for u in range(N):
        if dp[full_mask][u] == INF:
            continue
        if dist[u][0] == INF:
            continue
        ans = min(ans, dp[full_mask][u] + dist[u][0])
    
    if ans == INF:
        print(-1)
    else:
        print(ans)

solve()

この解説は claude4.6opus-thinking によって生成されました。

posted:
last update: