Official

D - ネットワークの構築 / Network Construction Editorial by admin

Claude 4.6 Opus (Thinking)

概要

\(N\) 台のサーバーを全域木で接続するとき、指定された \(K\) 本のケーブルを必ず含みつつ、敷設コストの合計を最小化する問題です。指定ケーブルを含む全域木が存在しない場合は \(-1\) を出力します。

考察

重要な気づき

この問題は「一部の辺が強制的に選ばれた状態での最小全域木」を求める問題です。

通常の最小全域木(MST)はクラスカル法で求められます。クラスカル法は「コストの小さい辺から順に、閉路を作らないものを採用する」というアルゴリズムです。

ここで重要なのは、必須の辺を先に全て追加してから、残りの辺でクラスカル法を行えばよいという点です。

なぜこれで正しいのか

必須辺を追加した後の状態を考えます。Union-Find で連結成分が管理されており、いくつかの「グループ」ができています。この後に必要なのは、これらのグループを全て1つに繋ぐことです。これはまさに、必須辺で縮約されたグラフ上での最小全域木問題に帰着されます。したがって、残りの辺をコスト順に見て、異なる連結成分を繋ぐものを貪欲に採用する(クラスカル法)のが最適です。

不可能なケース

以下の2つの場合に答えは \(-1\) です:

  1. 必須辺同士が閉路を作る場合: \(K\) 本の必須辺を追加する途中で、既に同じ連結成分に属する2頂点を繋ぐ辺が現れると、全域木(閉路なし)にできません。
  2. 全域木を完成できない場合: 必須辺と追加の辺を合わせても \(N-1\) 本に満たない(グラフ全体が非連結)場合です。

アルゴリズム

  1. 入力を読み取り、必須辺の集合を管理する。
  2. Union-Find を初期化する(\(N\) 頂点)。
  3. 必須辺を全て追加する。追加時に両端が既に同じ連結成分なら閉路が生じるので \(-1\) を出力して終了。コストを累計する。
  4. 非必須辺をコスト昇順にソートする。
  5. クラスカル法で、非必須辺を小さい順に見て、異なる連結成分を繋ぐ辺を採用していく。採用した辺の本数が合計 \(N-1\) 本になったら終了。
  6. 最終的に \(N-1\) 本の辺が選ばれていれば合計コストを出力、そうでなければ \(-1\) を出力。

具体例

\(N=4\), サーバー間に辺 \((1,2,10), (2,3,5), (3,4,3), (1,4,1)\) があり、辺 \((1,2,10)\) が必須の場合:

  • まず辺 \((1,2,10)\) を追加 → \(\{1,2\}, \{3\}, \{4\}\)
  • 残りをコスト順: \((1,4,1), (3,4,3), (2,3,5)\)
  • \((1,4,1)\) を採用 → \(\{1,2,4\}, \{3\}\)
  • \((3,4,3)\) を採用 → \(\{1,2,3,4\}\) → 完成!
  • 合計コスト: \(10 + 1 + 3 = 14\)

計算量

  • 時間計算量: \(O(M \log M)\)(非必須辺のソートがボトルネック。Union-Find の操作はほぼ \(O(\alpha(N))\) で定数的)
  • 空間計算量: \(O(N + M)\)

実装のポイント

  • Union-Find の経路圧縮とランクによる統合を使うことで、union/find 操作を高速に行う。

  • 必須辺の集合を set で管理し、\(O(1)\) で必須かどうかを判定する。

  • 辺のインデックスは問題文で 1-indexed だが、内部では 0-indexed に変換して扱う。

  • \(K = 0\) のとき最終行(必須辺リスト)が存在しないので、入力処理で場合分けが必要。

    ソースコード

import sys
input = sys.stdin.readline

def main():
    N, M, K = map(int, input().split())
    edges = []
    for i in range(M):
        u, v, w = map(int, input().split())
        edges.append((u, v, w))
    
    required = set()
    if K > 0:
        es = list(map(int, input().split()))
        for e in es:
            required.add(e - 1)  # 0-indexed
    
    # Union-Find
    parent = list(range(N + 1))
    rank = [0] * (N + 1)
    
    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x
    
    def union(x, y):
        rx, ry = find(x), find(y)
        if rx == ry:
            return False
        if rank[rx] < rank[ry]:
            rx, ry = ry, rx
        parent[ry] = rx
        if rank[rx] == rank[ry]:
            rank[rx] += 1
        return True
    
    total_cost = 0
    
    # Step 1: Add all required edges
    for idx in required:
        u, v, w = edges[idx]
        if not union(u, v):
            # Required edges form a cycle -> impossible
            print(-1)
            return
        total_cost += w
    
    # Step 2: Check if required edges alone don't disconnect things that can't be fixed
    # Now do Kruskal on remaining edges to complete the spanning tree
    # Sort non-required edges by weight
    non_required = []
    for i in range(M):
        if i not in required:
            non_required.append((edges[i][2], edges[i][0], edges[i][1]))
    
    non_required.sort()
    
    edges_used = K
    
    for w, u, v in non_required:
        if edges_used == N - 1:
            break
        if union(u, v):
            total_cost += w
            edges_used += 1
    
    if edges_used == N - 1:
        print(total_cost)
    else:
        print(-1)

main()

この解説は claude4.6opus-thinking によって生成されました。

posted:
last update: