
D - Minimum Steiner Tree Editorial by en_translator


If \(K=1\), the answer is clearly \(1\), since the single vertex \(V_1\) by itself forms a tree. From now on, we consider the other cases.

Let us say a tree is applicable if it is obtained by removing zero or more edges and vertices from the original tree and it contains all of the vertices \(V_1, \ldots, V_K\).

Also, let us say a vertex of a tree is bad if its degree in that tree is \(1\) and it is none of \(V_1,\ldots,V_K\).
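The two definitions can be made concrete with a small sketch. It assumes the tree is given as 0-indexed adjacency lists `adj`, and it represents a candidate tree by its vertex set `kept`; the names `is_applicable` and `is_bad` are hypothetical. This representation works because, inside a tree, deleting an edge between two kept vertices disconnects them, so an applicable tree is exactly a vertex set that induces a connected subgraph and contains all of \(V_1,\ldots,V_K\).

```python
from collections import deque


def is_applicable(adj, kept, required):
    """Return True if the vertex set `kept` induces a connected subgraph
    (hence a tree) that contains every vertex of the set `required`."""
    if not kept or not required <= kept:
        return False
    # Breadth-first search restricted to the vertices in `kept`.
    start = next(iter(kept))
    seen = {start}
    queue = deque([start])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w in kept and w not in seen:
                seen.add(w)
                queue.append(w)
    # The induced subgraph is connected exactly when every kept vertex was reached.
    return seen == kept


def is_bad(adj, kept, required, v):
    """Return True if v is a leaf of the tree induced by `kept`
    and is not one of the required vertices."""
    degree = sum(1 for w in adj[v] if w in kept)
    return v in kept and v not in required and degree == 1


# Hypothetical example: the path 0-1-2-3 with required vertices {1, 2}.
adj = [[1], [0, 2], [1, 3], [2]]
print(is_applicable(adj, {0, 1, 2}, {1, 2}))  # True
print(is_bad(adj, {0, 1, 2}, {1, 2}, 0))      # True
```

In this example, the set {0, 1, 2} is applicable, and vertex 0 is bad in it, so removing it still leaves an applicable tree.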

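The desired tree (the applicable tree with the minimum number of vertices, whose vertex count is the answer) is obtained by starting from the original tree and repeatedly removing a bad vertex together with its incident edge, as long as a bad vertex exists.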

Proof
1. An applicable tree with the minimum number of vertices does not have a bad vertex
If an applicable tree has a bad vertex, then the tree remains applicable even after removing that vertex. Thus, an applicable tree with the minimum number of vertices does not have a bad vertex.
2. An applicable tree without a bad vertex is an applicable tree with the minimum number of vertices
Take an arbitrary applicable tree $T$ without a bad vertex, and an arbitrary applicable tree $T'$ with the minimum number of vertices.
In general, the intersection of two subtrees of a tree is either empty or a tree. Here $T\cap T'$ is nonempty because both trees contain $V_1$, so it is a tree. Consider the tree $T''$ obtained from $T$ by contracting all vertices of $T\cap T'$ into a single vertex.
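If $T''$ has two or more vertices, then, being a tree, it has at least two vertices of degree $1$, so at least one of them, say $u$, is not the contracted vertex. Then $u$ is a vertex of $T$ that is not in $T'$; since $V_1,\ldots,V_K$ all lie in $T\cap T'$ and were contracted, $u$ is none of them. Moreover, $u$ is adjacent to at most one vertex of $T\cap T'$ (otherwise $T$ would contain a cycle), so its degree in $T$ equals its degree in $T''$, namely $1$. Thus $u$ is a bad vertex of $T$, which contradicts the assumption.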
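Therefore, $T''$ consists of a single vertex, which means that every vertex of $T$ lies in $T'$, that is, $T\subseteq T'$. Since $T$ is applicable and $T'$ has the minimum number of vertices among applicable trees, $T$ cannot be a proper subtree of $T'$; hence $T=T'$.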
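Since removing a bad vertex keeps a tree applicable, the removal process ends with an applicable tree that has no bad vertex, which by 2. has the minimum number of vertices. This completes the proof.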
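The claim just proved can be sanity-checked by brute force on small random trees. The sketch below is not part of the original editorial: `peel` (a hypothetical name) performs the removal of bad vertices using a `collections.deque` and returns the surviving vertices, `brute_force` tries every vertex subset, and the two are compared on random trees with at most 8 vertices. Vertices are 0-indexed, and the random-tree generator (attach each vertex to a uniformly random earlier one) is an arbitrary choice.

```python
import random
from collections import deque


def peel(n, edges, required):
    """Repeatedly delete bad vertices (leaves not in `required`)
    and return the set of surviving vertices."""
    adj = [set() for _ in range(n)]
    for a, b in edges:
        adj[a].add(b)
        adj[b].add(a)
    alive = set(range(n))
    queue = deque(v for v in range(n) if len(adj[v]) == 1)
    while queue:
        v = queue.popleft()
        if v in required or v not in alive or len(adj[v]) != 1:
            continue
        u = adj[v].pop()  # the unique neighbor of the bad leaf v
        adj[u].discard(v)
        alive.discard(v)
        if len(adj[u]) == 1:  # u has just become a leaf
            queue.append(u)
    return alive


def brute_force(n, edges, required):
    """Minimum size of a connected vertex subset containing `required`
    (exponential in n, so only usable for tiny trees)."""
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    best = n
    for mask in range(1, 1 << n):
        kept = {v for v in range(n) if mask >> v & 1}
        if not required <= kept:
            continue
        # Depth-first search inside `kept` to test connectivity.
        start = next(iter(kept))
        seen, stack = {start}, [start]
        while stack:
            u = stack.pop()
            for w in adj[u]:
                if w in kept and w not in seen:
                    seen.add(w)
                    stack.append(w)
        if seen == kept:
            best = min(best, len(kept))
    return best


def random_tree(n, rng):
    """Attach each vertex v >= 1 to a uniformly random earlier vertex."""
    return [(rng.randrange(v), v) for v in range(1, n)]


rng = random.Random(0)
for _ in range(300):
    n = rng.randint(1, 8)
    edges = random_tree(n, rng)
    required = set(rng.sample(range(n), rng.randint(1, n)))
    assert len(peel(n, edges, required)) == brute_force(n, edges, required)
print("all checks passed")
```

A randomized check like this does not replace the proof, but it is a cheap way to catch an off-by-one in an implementation of the removal process.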

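This removal process runs fast enough if we maintain, for each vertex, the set of vertices adjacent to it. Removing a bad leaf changes only the degree of its single neighbor, so afterwards we only need to check whether that neighbor has become a leaf. Each vertex is removed at most once, so the whole process takes \(O(N)\) expected time with hash sets.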

Writer’s solution (Python)

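```python
N, K = map(int, input().split())
# edge[v] is the set of current neighbors of vertex v (0-indexed).
edge = [set() for _ in range(N)]

for _ in range(N - 1):
    a, b = map(int, input().split())
    a -= 1
    b -= 1
    edge[a].add(b)
    edge[b].add(a)

# Required vertices V_1, ..., V_K, converted to 0-indexed.
V = set(int(x) - 1 for x in input().split())

# Start with every leaf of the original tree.
deg = [len(s) for s in edge]
q = [i for i, d in enumerate(deg) if d == 1]

ans = N
# Iterating over q while appending to it makes q act as a FIFO queue.
for v in q:
    if v in V:
        continue  # required leaves are never removed
    # v is a bad vertex: delete it together with its only edge.
    vv = edge[v].pop()
    edge[vv].discard(v)
    ans -= 1
    # If the neighbor has just become a leaf, it may now be bad.
    if len(edge[vv]) == 1:
        q.append(vv)

print(ans)
```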
