公式

E - 配送ルートの最適化 / Optimization of Delivery Routes 解説 by admin

GPT 5.2 High

概要

配送センター(地点 \(1\))から出発して全地点をちょうど1回ずつ巡り、最後に地点 \(1\) に戻るときの総コスト(移動ごとの二乗距離の和)を最小化する、典型的な巡回セールスマン問題(TSP)を \(N \le 16\) の範囲で解きます。

考察

  • すべての訪問順序を試す素朴解は、順列の数が \((N-1)!\) 通りあり、例えば \(N=16\) だと \(15! \approx 1.3 \times 10^{12}\) で現実的に計算できません(TLE)。
  • しかし \(N \le 16\) と小さいため、「訪問した集合」をビットマスクで管理する bit DP(Held-Karp 法) が使えます。
  • コストは \((a-c)^2+(b-d)^2\)(二乗距離)で、三角不等式が成り立つ必要はありません。TSP の bit DP は「部分集合の最適」を積み上げるだけなので、このコストでも正しく最小値を求められます。
  • スタート地点は必ず \(1\) なので、地点 \(1\) を除いた \(N-1\) 個だけを部分集合 DP の対象にすると状態数を減らせます。

アルゴリズム

地点 \(1\) をスタート(インデックス \(0\))とし、残りの地点を \(0 \ldots n-1\)(ただし \(n=N-1\))として扱います。

1. 距離(コスト)の前計算

  • \(d0[i]\):地点 \(1\) → 地点 \(i\)\(i\) は「残りの地点」の番号)のコスト
  • \(d[i][j]\):残りの地点 \(i\) → 残りの地点 \(j\) のコスト
  • 戻りコスト \(dback[i]\) は地点 \(i\) → 地点 \(1\) で、ここでは \(d0[i]\) と同じ値です(二乗距離は対称)。

毎回 \((x,y)\) から二乗距離を計算すると遷移が重くなるので、先に全部計算しておきます。

2. DP 状態

\(dp[mask][i]\) を次のように定義します: - \(mask\):訪問済みの「残りの地点」の集合(ビットマスク、サイズは \(2^n\)) - \(i\):最後にいる地点(残りの地点のどれか) - 値:地点 \(1\) を出発して、\(mask\) に含まれる地点をすべて1回ずつ訪問し、最後に地点 \(i\) にいるときの最小コスト

初期化: - \(dp[1 \ll i][i] = d0[i]\)(地点 \(1\) から直接 \(i\) に行く)

遷移: - \(dp[mask \cup \{j\}][j] = \min\left(dp[mask \cup \{j\}][j],\ dp[mask][i] + d[i][j]\right)\)
(ただし \(j \notin mask\)

答え: - 全地点訪問後(\(mask = (1 \ll n)-1\))に地点 \(1\) に戻るので
\(\min_i \left(dp[all][i] + dback[i]\right)\)

3. コード上の工夫(高速化)

  • \(dp\) を 2 次元配列ではなく 1 次元に潰し、base[mask] = mask * n のようなオフセットで dp[base[mask] + i] としてアクセスしています。
  • ビット演算で「集合の要素列挙」を行い、lsb = m & -m(最下位ビット)を使って高速に走査しています。
  • bit_index[1<<i] = i を用意し、「立っているビット → 頂点番号」を \(O(1)\) で引けるようにしています。

計算量

  • 時間計算量: \(O(n^2 2^n)\)(ここで \(n=N-1\)
    各部分集合 \(mask\)\(2^n\) 通り)について、終点 \(i\)(最大 \(n\) 通り)から未訪問 \(j\)(最大 \(n\) 通り)へ遷移するためです。
  • 空間計算量: \(O(n 2^n)\)
    \(dp[mask][i]\) を全状態持つためです。

実装のポイント

  • 地点 \(1\) を除外して DP:スタートが固定なので、部分集合は残り \(N-1\) 個だけにすると実装がシンプルかつ高速です。

  • 二乗距離は必ず整数\((a-c)^2+(b-d)^2\) は整数なので、誤差の心配なく整数 DP ができます。

  • \(N=2\) の特別扱い:訪問順が実質 1→2→1 のみなので、コストは \(2 \times \text{dist}^2(1,2)\) を即座に出せます(コードでも分岐しています)。

  • INF の設定:最大コストは十分小さい(座標差は最大 \(2000\)、二乗で \(4\times10^6\)、辺は最大 \(16\) 本程度)ので、INF=10**18 のように大きく取れば安全です。

    ソースコード

import sys
from array import array

def main():
    input = sys.stdin.readline
    N = int(input())
    xs = [0] * N
    ys = [0] * N
    for i in range(N):
        x, y = map(int, input().split())
        xs[i] = x
        ys[i] = y

    if N == 2:
        dx = xs[0] - xs[1]
        dy = ys[0] - ys[1]
        c = dx * dx + dy * dy
        print(2 * c)
        return

    n = N - 1  # nodes excluding start(0), indices 0..n-1 correspond to original 1..N-1
    x0, y0 = xs[0], ys[0]

    d0 = [0] * n
    dback = [0] * n
    d = [[0] * n for _ in range(n)]

    for i in range(n):
        dx = x0 - xs[i + 1]
        dy = y0 - ys[i + 1]
        c = dx * dx + dy * dy
        d0[i] = c
        dback[i] = c

    for i in range(n):
        xi, yi = xs[i + 1], ys[i + 1]
        row = d[i]
        for j in range(n):
            dx = xi - xs[j + 1]
            dy = yi - ys[j + 1]
            row[j] = dx * dx + dy * dy

    full = 1 << n
    fullmask = full - 1

    base = [m * n for m in range(full)]
    bit_index = [0] * full
    for i in range(n):
        bit_index[1 << i] = i

    INF = 10**18
    dp = array('Q', [INF]) * (full * n)

    for i in range(n):
        dp[base[1 << i] + i] = d0[i]

    for mask in range(1, full):
        base_mask = base[mask]
        rem0 = fullmask ^ mask
        m = mask
        while m:
            lsb = m & -m
            i = bit_index[lsb]
            cur = dp[base_mask + i]
            if cur != INF:
                row = d[i]
                r = rem0
                while r:
                    lb = r & -r
                    j = bit_index[lb]
                    nm = mask | lb
                    idx = base[nm] + j
                    val = cur + row[j]
                    if val < dp[idx]:
                        dp[idx] = val
                    r -= lb
            m -= lsb

    ans = INF
    base_all = base[fullmask]
    for i in range(n):
        cur = dp[base_all + i]
        val = cur + dback[i]
        if val < ans:
            ans = val

    print(ans)

if __name__ == "__main__":
    main()

この解説は gpt-5.2-high によって生成されました。

投稿日時:
最終更新: