D - Minimum Steiner Tree Editorial by en_translator
If \(K=1\), the answer is obviously \(1\). We now consider the other cases.
Let us say a tree is applicable if it is obtained from the original tree by removing zero or more edges and vertices, and it contains all of the vertices \(V_1, \ldots, V_K\).
Also, let us call a vertex bad if its degree is \(1\) and it is none of \(V_1,\ldots,V_K\).
The desired tree is obtained by starting from the original tree and repeatedly removing a bad vertex together with its incident edge until no bad vertex remains.
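The removal rule can be sketched as a function that returns the surviving vertices themselves rather than only their count. This is a minimal sketch assuming 0-indexed vertices, the input tree given as an edge list, and the required vertices given as a set; the function name `prune` is hypothetical.

```python
from collections import deque


def prune(n, edges, targets):
    """Return the vertices left after repeatedly deleting bad vertices.

    Vertices are 0-indexed; `edges` lists the n-1 edges of a tree and
    `targets` holds the required vertices V_1, ..., V_K (hypothetical API).
    """
    adj = [set() for _ in range(n)]
    for a, b in edges:
        adj[a].add(b)
        adj[b].add(a)
    alive = set(range(n))
    # Start from every current leaf; a leaf is bad if it is not a target.
    queue = deque(v for v in range(n) if len(adj[v]) == 1)
    while queue:
        v = queue.popleft()
        if v in targets or len(adj[v]) != 1:
            continue
        (u,) = adj[v]          # the unique neighbor of the bad leaf v
        adj[v].clear()
        adj[u].discard(v)
        alive.discard(v)
        if len(adj[u]) == 1:   # u may have just become a leaf
            queue.append(u)
    return alive


edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]
print(sorted(prune(7, edges, {0, 2, 4})))  # [0, 1, 2, 4]
```

Here vertices 3, 5 and 6 are bad leaves and get removed, while vertex 2 becomes a leaf but is kept because it is required.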
Proof
1. An applicable tree with the minimum number of vertices has no bad vertex
If an applicable tree has a bad vertex, then removing that vertex and its incident edge yields a tree that is still applicable and has fewer vertices. Thus, an applicable tree with the minimum number of vertices has no bad vertex.
2. An applicable tree without a bad vertex is an applicable tree with the minimum number of vertices
Take an arbitrary applicable tree $T$ without a bad vertex, and an arbitrary applicable tree $T'$ with the minimum number of vertices.
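In general, if we take two subtrees of a tree, their intersection is either a tree or empty. Here both $T$ and $T'$ contain $V_1$, so $T\cap T'$ is a nonempty tree. Let $T''$ be the tree obtained from $T$ by contracting all vertices of $T\cap T'$ into a single vertex.
If $T''$ has two or more vertices, then it has at least two vertices of degree $1$. Since $V_1,\ldots,V_K$ all belong to $T\cap T'$, they are all merged into the contracted vertex, so at least one of these degree-$1$ vertices is an uncontracted vertex that is none of $V_1,\ldots,V_K$. A vertex outside $T\cap T'$ has at most one neighbor in it (otherwise $T$ would contain a cycle), so contraction does not change its degree; hence this vertex also has degree $1$ in $T$ and is a bad vertex of $T$, contradicting the assumption.
Therefore, $T''$ consists of a single vertex, which means $T\subseteq T'$. Since $T$ is also applicable, the minimality of $T'$ gives $T=T'$.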
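This completes the proof.
The removal process runs fast enough if we maintain, for each vertex, the set of vertices adjacent to it, and push a vertex onto a queue whenever it becomes a leaf. Each vertex is removed at most once, so the whole process takes $O(N)$ expected time.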
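To gain some confidence in the claim on small inputs, one can compare the result of the removal process with an exhaustive search over all vertex subsets. This sketch relies on the fact that, inside a tree, a vertex subset spans an applicable tree exactly when it induces a connected subgraph containing every required vertex; the helper names (`prune_count`, `is_connected`, `brute_force`, `random_tree`) and the random test parameters are illustrative assumptions, not part of the original solution.

```python
import itertools
import random
from collections import deque


def prune_count(n, adj, targets):
    """Number of vertices left after repeatedly removing bad vertices."""
    nbrs = [set(s) for s in adj]  # copy so the caller's tree is untouched
    queue = deque(v for v in range(n) if len(nbrs[v]) == 1)
    removed = 0
    while queue:
        v = queue.popleft()
        if v in targets or len(nbrs[v]) != 1:
            continue
        (u,) = nbrs[v]
        nbrs[v].clear()
        nbrs[u].discard(v)
        removed += 1
        if len(nbrs[u]) == 1:
            queue.append(u)
    return n - removed


def is_connected(subset, adj):
    """True if the vertices in `subset` induce a connected subgraph."""
    start = next(iter(subset))
    seen = {start}
    stack = [start]
    while stack:
        v = stack.pop()
        for u in adj[v]:
            if u in subset and u not in seen:
                seen.add(u)
                stack.append(u)
    return seen == subset


def brute_force(n, adj, targets):
    """Size of the smallest connected vertex set containing all targets."""
    for size in range(1, n + 1):
        for combo in itertools.combinations(range(n), size):
            subset = set(combo)
            if targets <= subset and is_connected(subset, adj):
                return size
    return n  # unreachable: the whole tree always qualifies


def random_tree(n, rng):
    """Random tree on n vertices: attach each vertex to an earlier one."""
    adj = [set() for _ in range(n)]
    for v in range(1, n):
        p = rng.randrange(v)
        adj[v].add(p)
        adj[p].add(v)
    return adj


rng = random.Random(0)
for _ in range(300):
    n = rng.randint(1, 8)
    adj = random_tree(n, rng)
    targets = set(rng.sample(range(n), rng.randint(1, n)))
    assert prune_count(n, adj, targets) == brute_force(n, adj, targets)
print("all random tests passed")
```

The exhaustive search is exponential in $N$, so it is only meant for tiny trees as a sanity check, not as a solution.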
Writer’s solution (Python)
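```python
N,K = map(int,input().split())
# edge[v]: set of vertices currently adjacent to v
edge = [set() for _ in range(N)]
for _ in range(N-1):
    a,b = map(int,input().split())
    a-=1
    b-=1
    edge[a].add(b)
    edge[b].add(a)
# V: the required vertices, converted to 0-indexed
V = set(map(int,input().split()))
V = set(x-1 for x in V)
deg = [len(s) for s in edge]
# q starts with all leaves; a for loop over a list also visits items
# appended during the loop, so q behaves as a FIFO queue
q = [i for i,d in enumerate(deg) if d==1]
ans = N
for v in q:
    if v in V: continue  # required vertices are never removed
    vv = edge[v].pop()   # the unique neighbor of the bad vertex v
    edge[vv].discard(v)
    ans-=1
    if len(edge[vv])==1: q.append(vv)  # vv has just become a leaf
print(ans)
```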
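As an alternative to neighbor sets, the same process can be implemented with plain adjacency lists and degree counters: when a bad leaf is removed, only its one surviving neighbor loses a degree. This is a sketch assuming the input format read by the writer's solution; the function name `solve` is hypothetical, and it takes the whole input as a string so that it can be called directly.

```python
from collections import deque


def solve(data: str) -> int:
    """Count the vertices left after removing bad vertices (hypothetical API)."""
    it = iter(data.split())
    n, k = int(next(it)), int(next(it))
    adj = [[] for _ in range(n)]
    deg = [0] * n  # number of not-yet-removed neighbors
    for _ in range(n - 1):
        a, b = int(next(it)) - 1, int(next(it)) - 1
        adj[a].append(b)
        adj[b].append(a)
        deg[a] += 1
        deg[b] += 1
    is_target = [False] * n
    for _ in range(k):
        is_target[int(next(it)) - 1] = True
    removed = [False] * n
    # Only bad leaves ever enter the queue.
    queue = deque(v for v in range(n) if deg[v] == 1 and not is_target[v])
    ans = n
    while queue:
        v = queue.popleft()
        removed[v] = True
        ans -= 1
        for u in adj[v]:
            if not removed[u]:
                deg[u] -= 1  # the single surviving neighbor of v
                if deg[u] == 1 and not is_target[u]:
                    queue.append(u)
    return ans
```

To use it as a submission, add `import sys` and call `print(solve(sys.stdin.read()))`. Since degrees only decrease, each vertex reaches degree $1$ at most once and is enqueued at most once, and the neighbor scans sum to $O(N)$ in total.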