公式

E - 妨害と通信ネットワーク / Jamming and Communication Network 解説 by admin

Claude 4.6 Opus (Thinking)

概要

2-辺連結グラフにおいて、青木君が1本の辺を切断した後に高橋君が最小全域木(MST)を構築するというゲームで、青木君が最適に妨害したときのMSTコストの最大値を求める問題です。

考察

重要な気づき

まず、青木君が何も妨害しない場合の最小全域木を考えます。このMSTのコストを \(C_{\text{MST}}\) とします。

青木君が切断する辺には2つのケースがあります:

  1. MST に含まれない辺を切断した場合: MSTはそのまま構築できるので、コストは \(C_{\text{MST}}\) のまま変わりません。
  2. MST に含まれる辺 \(e\) を切断した場合: MSTが使えなくなり、辺 \(e\) の代わりに別の辺(代替辺)を使う必要があります。

青木君はコストを最大化したいので、MST の辺を切断する方が有利です。

MST の辺を1本除去したときの再構築

MST から辺 \(e\)(重み \(w_e\)、頂点 \(a\)-\(b\) を結ぶ)を除去すると、木が2つの連結成分に分かれます。再び全域木にするには、この2成分を繋ぐ非MST辺のうち最小重みのもの \(f\)(重み \(w_f\))で置き換えます。

このとき新しいMSTコストは \(C_{\text{MST}} - w_e + w_f\) となり、コスト増加分は \(w_f - w_e\) です。

青木君はこの増加分が最大となるMST辺 \(e\) を選ぶので、答えは:

\[C_{\text{MST}} + \max_{e \in \text{MST}} \left( \min_{f \text{ が } e \text{ をカバー}} w_f - w_e \right)\]

素朴なアプローチの問題点

各MST辺について、それをカバーする非MST辺の最小重みを愚直に求めると、各非MST辺についてパス上の全辺を列挙する必要があり、最悪 \(O(NM)\) でTLEします。

解決策:HLD(Heavy-Light Decomposition)

非MST辺 \((u, v, w)\) はMST上の \(u\)-\(v\) パスにある全てのMST辺を「カバー」します。これはパス上の区間に対する最小値更新と見なせます。HLDを使えば各パス更新を \(O(\log^2 N)\) で処理できます。

アルゴリズム

  1. Kruskal法でMSTを構築し、MSTコスト \(C_{\text{MST}}\) とMSTの辺集合を求める。
  2. MSTを根付き木として構築し、各辺を子ノードに対応づける(辺 \(u\)-\(v\)\(v\) が子なら、ノード \(v\) がその辺を表す)。
  3. HLD(Heavy-Light Decomposition)を構築する。各ノードにHLD上の位置を割り当てる。
  4. セグメント木(区間min更新・点クエリ)を用意し、初期値を \(\infty\) とする。
  5. 各非MST辺 \((u, v, w)\) について、MST上の \(u\)-\(v\) パスに対応するHLD区間に重み \(w\) で最小値更新を行う。
  6. 各MST辺(ルート以外の各ノード \(v\))について、セグメント木から \(\text{pos}[v]\) の値を取得し、それが最小代替辺の重み \(w_f\) となる。増加分 \(w_f - w_e\) の最大値を求める。
  7. 答えは \(C_{\text{MST}} + \max(w_f - w_e)\)

計算量

  • 時間計算量: \(O(M \log^2 N)\)(Kruskal \(O(M \log M)\) + 各非MST辺のHLDパス更新 \(O(M \log^2 N)\)
  • 空間計算量: \(O(N + M)\)

実装のポイント

  • 辺をノードに対応づける: 根付き木で各辺を子ノード側で管理する。パス更新時にLCAのノード自体は含めない(pos_in_hld[u]+1 から pos_in_hld[v]+1 の範囲)。

  • セグメント木は区間更新・点クエリ型: 通常のセグメント木とは逆で、更新が区間、クエリが点。遅延伝播不要で、更新時に該当するセグメント木ノードに最小値を記録し、クエリ時に葉から根まで辿って最小値を取れば十分。

  • 非MST辺のみ処理: MST辺はカバー対象なので、パス更新するのは非MST辺のみ。

    ソースコード

import sys
from collections import defaultdict

def main():
    input_data = sys.stdin.buffer.read().split()
    idx = 0
    N = int(input_data[idx]); idx += 1
    M = int(input_data[idx]); idx += 1
    
    edges = []
    for i in range(M):
        u = int(input_data[idx]); idx += 1
        v = int(input_data[idx]); idx += 1
        w = int(input_data[idx]); idx += 1
        edges.append((w, u, v, i))
    
    # Sort edges by weight
    sorted_edges = sorted(edges)
    
    # Kruskal's to find MST
    parent = list(range(N + 1))
    rank = [0] * (N + 1)
    
    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x
    
    def union(a, b):
        a, b = find(a), find(b)
        if a == b:
            return False
        if rank[a] < rank[b]:
            a, b = b, a
        parent[b] = a
        if rank[a] == rank[b]:
            rank[a] += 1
        return True
    
    mst_cost = 0
    mst_edge_indices = set()
    mst_adj = defaultdict(list)  # adjacency list for MST tree
    
    for w, u, v, i in sorted_edges:
        if union(u, v):
            mst_cost += w
            mst_edge_indices.add(i)
            mst_adj[u].append((v, w, i))
            mst_adj[v].append((u, w, i))
            if len(mst_edge_indices) == N - 1:
                break
    
    # For each MST edge, if we remove it, the MST splits into two components.
    # We need the minimum weight non-MST edge that reconnects them.
    # This is equivalent to: for each non-MST edge (u,v,w), it covers all MST edges
    # on the path from u to v in the MST. For each MST edge, the "replacement cost increase"
    # is min(non-MST edge weight crossing it) - (MST edge weight).
    
    # We need: for each MST edge e, find min weight of non-tree edges whose path covers e.
    # Then answer = mst_cost + max over MST edges e of (min_replacement[e] - weight[e])
    
    # Use HLD or Euler tour + LCA with segment tree for path min updates.
    # We'll use HLD approach: for each non-tree edge, update the path in HLD with its weight,
    # keeping track of minimums.
    
    # Build tree rooted at 1
    # BFS to set up parent, depth, subtree size
    root = 1
    tree_children = defaultdict(list)
    tree_parent = [0] * (N + 1)
    tree_depth = [0] * (N + 1)
    tree_edge_weight = [0] * (N + 1)  # weight of edge from node to its parent
    tree_edge_id = [0] * (N + 1)  # edge id from node to its parent
    
    visited = [False] * (N + 1)
    from collections import deque
    queue = deque([root])
    visited[root] = True
    order = []
    
    while queue:
        u = queue.popleft()
        order.append(u)
        for v, w, eid in mst_adj[u]:
            if not visited[v]:
                visited[v] = True
                tree_parent[v] = u
                tree_depth[v] = tree_depth[u] + 1
                tree_edge_weight[v] = w
                tree_edge_id[v] = eid
                tree_children[u].append(v)
                queue.append(v)
    
    # Compute subtree sizes
    subtree_size = [1] * (N + 1)
    for u in reversed(order):
        for c in tree_children[u]:
            subtree_size[u] += subtree_size[c]
    
    # HLD
    heavy_child = [0] * (N + 1)
    for u in order:
        best = -1
        best_size = 0
        for c in tree_children[u]:
            if subtree_size[c] > best_size:
                best_size = subtree_size[c]
                best = c
        heavy_child[u] = best
    
    chain_head = [0] * (N + 1)
    pos_in_hld = [0] * (N + 1)
    current_pos = 0
    
    # Assign HLD positions via DFS-like BFS respecting heavy child first
    stack = [(root, root)]
    while stack:
        u, head = stack.pop()
        chain_head[u] = head
        pos_in_hld[u] = current_pos
        current_pos += 1
        # Push light children first (so heavy child is processed first from stack)
        children = tree_children[u]
        light = []
        hc = heavy_child[u]
        for c in children:
            if c != hc:
                light.append(c)
        for c in reversed(light):
            stack.append((c, c))
        if hc != -1:
            stack.append((hc, head))
    
    # Segment tree for range min update, point query
    INF = float('inf')
    seg_size = 1
    while seg_size < N:
        seg_size *= 2
    seg = [INF] * (2 * seg_size)
    
    def seg_update(l, r, val):
        # Update [l, r) with min
        l += seg_size
        r += seg_size
        while l < r:
            if l & 1:
                if val < seg[l]:
                    seg[l] = val
                l += 1
            if r & 1:
                r -= 1
                if val < seg[r]:
                    seg[r] = val
            l >>= 1
            r >>= 1
    
    def seg_query(p):
        # Point query at p - get min of all ranges covering p
        p += seg_size
        res = INF
        while p >= 1:
            if seg[p] < res:
                res = seg[p]
            p >>= 1
        return res
    
    # LCA and path update using HLD
    def update_path(u, v, val):
        while chain_head[u] != chain_head[v]:
            if tree_depth[chain_head[u]] < tree_depth[chain_head[v]]:
                u, v = v, u
            # u's chain head is deeper
            seg_update(pos_in_hld[chain_head[u]], pos_in_hld[u] + 1, val)
            u = tree_parent[chain_head[u]]
        if u == v:
            return
        if tree_depth[u] > tree_depth[v]:
            u, v = v, u
        # u is ancestor, update edges from u's child down to v
        # Edge is represented by child node, so update pos_in_hld of nodes from depth(u)+1 to v
        # In HLD, that's [pos_in_hld[u]+1, pos_in_hld[v]+1)
        seg_update(pos_in_hld[u] + 1, pos_in_hld[v] + 1, val)
    
    # Process all non-tree edges
    for w, u, v, i in sorted_edges:
        if i not in mst_edge_indices:
            update_path(u, v, w)
    
    # For each MST edge (represented by child node), find replacement cost
    best_increase = 0
    for node in order:
        if node == root:
            continue
        eid = tree_edge_id[node]
        ew = tree_edge_weight[node]
        min_replacement = seg_query(pos_in_hld[node])
        # If we remove this MST edge, cost becomes mst_cost - ew + min_replacement
        increase = min_replacement - ew
        if increase > best_increase:
            best_increase = increase
    
    print(mst_cost + best_increase)

main()

この解説は claude4.6opus-thinking によって生成されました。

投稿日時:
最終更新: