公式

D - ネットワークの構築 / Network Construction 解説 by admin

GPT 5.2 High

概要

必ず含めるべき \(K\) 本のケーブルを先に採用し、それらと矛盾(閉路)しない範囲で残りを最小コストで追加して全域木を完成させる問題です。

考察

全域木は「\(N\) 頂点を連結にし、閉路を作らず、辺数がちょうど \(N-1\)」という構造です。
ここで 指定されたケーブルは必ず使う必要があるため、通常の最小全域木(MST)よりも制約が強くなります。

重要な観察は次の2点です。

  1. 必須ケーブルだけで閉路ができたら不可能
    全域木には閉路が存在できません。したがって、必須ケーブルをすべて入れた時点で同じ連結成分内を結ぶ辺(=閉路を作る辺)が出たら、その時点で条件を満たす全域木は作れず、答えは \(-1\) です。

  2. 必須ケーブルを入れた状態からの最小追加は、Kruskal法でよい
    必須ケーブルを「すでに選ばれた辺」とみなし、その状態で残りの辺をコストの小さい順に見ていき、閉路を作らないものだけ採用すれば、追加分のコストは最小になります。
    これはKruskal法(最小全域木を作る貪欲法)が「今ある連結成分を最小コストで繋ぐ」戦略として最適であるためです。

素朴に「必須を含む全域木を全探索」すると、辺の選び方は組合せ爆発します(\(M\) が最大 \(2\times10^5\) なので到底不可能)。
そこで Union-Find(DSU)で連結性・閉路判定を高速に行い、Kruskal法の流れに乗せて解きます。

アルゴリズム

  1. 入力された辺を保持する(\((W_i, U_i, V_i, i)\) の形など)。
  2. 必須辺集合をフラグ配列 mandatory で管理する。
  3. Union-Find を初期化する。
  4. 必須辺を先にすべて追加する。
    • \((u,v)\) について union(u,v) が失敗(既に同一成分)なら、必須辺だけで閉路ができるため \(-1\)
    • 成功したらコストを加算し、使用辺数 used を増やす。
  5. 全辺をコスト昇順にソートし、必須でない辺だけを小さい順に見ていく。
    • used == N-1 になったら全域木完成なので終了。
    • mandatory な辺はスキップ(既に採用済み)。
    • union(u,v) が成功する(閉路にならない)なら採用してコスト加算、used++
  6. 最後に used != N-1 なら、そもそもグラフが連結にできない(必須制約のせいで繋げない等)ので \(-1\)。そうでなければ合計コストを出力。

(具体イメージ)
必須辺でいくつかの「島(連結成分)」ができたと考えます。その島同士を、安いケーブルから順に「島を繋げるものだけ」選んでいくと最小になります。

計算量

  • 時間計算量: ソートが支配的で \(O(M \log M)\)(Union-Find の操作は合計でほぼ \(O(M)\)
  • 空間計算量: 辺情報とUnion-Findで \(O(N+M)\)

実装のポイント

  • 必須辺の閉路チェックが最重要:必須辺を追加する段階で union 失敗なら即 \(-1\)

  • 必須辺を二重に数えない:ソート後の走査では mandatory[idx] を見て必須辺をスキップする。

  • 全域木の完成判定:採用辺数がちょうど \(N-1\) になったかで判定する(連結判定を別にやらなくてよい)。

  • 入力が大きいので、コードのように sys.stdin.buffer.read() でまとめて読むと高速です。

    ソースコード

import sys

class DSU:
    __slots__ = ("p", "sz")
    def __init__(self, n):
        self.p = list(range(n))
        self.sz = [1] * n

    def find(self, x):
        p = self.p
        while p[x] != x:
            p[x] = p[p[x]]
            x = p[x]
        return x

    def union(self, a, b):
        a = self.find(a)
        b = self.find(b)
        if a == b:
            return False
        if self.sz[a] < self.sz[b]:
            a, b = b, a
        self.p[b] = a
        self.sz[a] += self.sz[b]
        return True

def main():
    data = list(map(int, sys.stdin.buffer.read().split()))
    it = iter(data)
    try:
        N = next(it)
    except StopIteration:
        return
    M = next(it)
    K = next(it)

    U = [0] * (M + 1)
    V = [0] * (M + 1)
    W = [0] * (M + 1)
    edges = []
    for i in range(1, M + 1):
        u = next(it); v = next(it); w = next(it)
        U[i] = u; V[i] = v; W[i] = w
        edges.append((w, u, v, i))

    mandatory = [False] * (M + 1)
    mand_list = []
    for _ in range(K):
        idx = next(it)
        mandatory[idx] = True
        mand_list.append(idx)

    dsu = DSU(N + 1)
    total = 0
    used = 0

    for idx in mand_list:
        u, v, w = U[idx], V[idx], W[idx]
        if not dsu.union(u, v):
            print(-1)
            return
        total += w
        used += 1

    edges.sort(key=lambda x: x[0])
    for w, u, v, idx in edges:
        if used == N - 1:
            break
        if mandatory[idx]:
            continue
        if dsu.union(u, v):
            total += w
            used += 1

    if used != N - 1:
        print(-1)
    else:
        print(total)

if __name__ == "__main__":
    main()

この解説は gpt-5.2-high によって生成されました。

投稿日時:
最終更新: