公式

D - 離島を結ぶ橋 / Bridges Connecting Remote Islands 解説 by admin

Qwen3-Coder-480B

概要

\(N\) 個の島と、それらを結ぶコスト付きの橋の候補が与えられる。すべての島を行き来可能にするために必要な最小コストを求めよ。不可能な場合は \(-1\) を出力せよ。

考察

この問題は、グラフの最小全域木(Minimum Spanning Tree, MST)を求める問題に帰着されます。
島を頂点、橋を辺、コストを重みとした無向グラフを考えると、「すべての島を行き来可能にする」=「グラフを連結にする」=「全域木を作る」と言い換えられます。さらに、コストを最小にしたいので「最小全域木」を求めることになります。

素朴な方法として、全辺を試して連結性を毎回確認する方法がありますが、辺の数が多い場合(最大 \(2 \times 10^5\))非常に非効率で、時間内に解けません。そこで、効率的に最小コストで連結性を管理するために、Union-Find(Disjoint Set Union) と呼ばれるデータ構造を用いるのが一般的です。

また、最小コストを実現するには、辺をコストの昇順にソートし、貪欲に選択していくことで構築できます(Kruskal法)。これにより、無駄な辺を選ばずに最小コストの全域木を得られます。

最後に、もし最終的に連結成分が1つにならなければ、すべての島を連結にすることは不可能なので -1 を出力します。

アルゴリズム

この問題では Kruskal のアルゴリズム を使用します。

手順:

  1. すべての辺をコストの昇順にソートする。
  2. Union-Find を用意し、初期状態ではすべての頂点が独立した集合に属しているとする。
  3. コストが小さい順に辺を見ていく。
    • 辺の両端が異なる集合に属している場合(= unite が成功する場合)、その辺を採用し、コストを加算する。
    • すべての頂点が一つの集合に属したら終了。
  4. 最終的に連結成分数が1でなければ -1 を出力。

計算量

  • 時間計算量: \(O(M \log M)\)
    (辺のソートに \(O(M \log M)\)、Union-Find操作がおおよそ \(O(\alpha(N))\) で、全体では支配項はソート)
  • 空間計算量: \(O(N + M)\)
    (Union-Findの管理領域と辺リストの分)

実装のポイント

  • 頂点番号は 0-indexed に変換しておく(\(1\) から \(N\) の入力を \(0\) から \(N-1\) にする)。

  • Union-Find は経路圧縮とランクによるマージ最適化を行うことで高速化される。

  • components 変数を使って連結成分の数を管理し、1 になった時点で終了することで無駄な処理を避ける。

  • グラフが最初から連結(\(N = 1\))の場合はコスト 0 でよいことに注意。

    ソースコード

class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.components = n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def unite(self, x, y):
        xr = self.find(x)
        yr = self.find(y)
        if xr != yr:
            if self.rank[xr] < self.rank[yr]:
                xr, yr = yr, xr
            self.parent[yr] = xr
            if self.rank[xr] == self.rank[yr]:
                self.rank[xr] += 1
            self.components -= 1
            return True
        return False

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    
    N = int(data[0])
    M = int(data[1])
    
    edges = []
    idx = 2
    for _ in range(M):
        u = int(data[idx]) - 1
        v = int(data[idx+1]) - 1
        c = int(data[idx+2])
        edges.append((c, u, v))
        idx += 3
    
    edges.sort()
    
    uf = UnionFind(N)
    total_cost = 0
    
    for cost, u, v in edges:
        if uf.unite(u, v):
            total_cost += cost
            if uf.components == 1:
                print(total_cost)
                return
    
    if N == 1:
        print(0)
    else:
        print(-1)

if __name__ == "__main__":
    main()

この解説は qwen3-coder-480b によって生成されました。

投稿日時:
最終更新: