F - Paint Tree 2 Editorial by en_translator (official)


Regard the given tree \(T\) as rooted at vertex \(1\).

We compute the answer with the following tree DP (dynamic programming):

  • Define \(d_c[i][k]\) as follows.
    • Let \(d_c[i][k]\) be the maximum sum of the integers written on the vertices covered by \(k\) vertex-disjoint paths chosen within the subtree rooted at \(i\). (Vertex-disjoint means that no two paths share a vertex.)
    • If \(c=0\), there is no additional constraint; if \(c=1\), one of the paths must have vertex \(i\) as an endpoint.

The answer is \(\displaystyle \max_{0 \le k \le K} d_0[1][k]\).
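To make the maximized quantity concrete, here is a brute-force reference for tiny inputs. It assumes, from the DP definition above, that the task is to choose at most \(K\) vertex-disjoint paths (a single vertex counts as a path) and maximize the sum of the integers on the covered vertices. The names `brute_force` and `tree_path` are hypothetical, vertices are 0-indexed, and the search is exponential, so it is only meant for cross-checking the DP on small trees.

```python
from itertools import combinations


def tree_path(adj, u, v):
    """Vertex set of the unique u-v path in a tree given as adjacency lists."""
    parent = {u: None}
    queue = [u]
    for x in queue:  # BFS from u; the list grows while we iterate over it
        for y in adj[x]:
            if y not in parent:
                parent[y] = x
                queue.append(y)
    path = []
    while v is not None:  # walk back from v to u
        path.append(v)
        v = parent[v]
    return frozenset(path)


def brute_force(n, k, a, edges):
    """Assumed task: max total of a[] over at most k vertex-disjoint paths."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    # A path is determined by its endpoint pair; u == v gives a one-vertex path.
    paths = [tree_path(adj, u, v) for u in range(n) for v in range(u, n)]
    best = 0  # choosing no path at all is allowed
    for m in range(1, k + 1):
        for combo in combinations(paths, m):
            covered = frozenset().union(*combo)
            # The paths are vertex-disjoint iff no vertex is counted twice.
            if len(covered) == sum(len(p) for p in combo):
                best = max(best, sum(a[x] for x in covered))
    return best
```

For instance, on the path 0-1-2 with `a = [1, -5, 2]`, `brute_force(3, 2, a, [(0, 1), (1, 2)])` returns 3 by taking the two one-vertex paths {0} and {2}.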

Merging the DP table of one child into that of its parent takes \(O(K^2)\) time, and there are \(N-1\) such merges, so the answer can be found in a total of \(O(NK^2)\) time.
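The simplest transition of the merge in the sample code, `new_d0[i + j] = max(new_d0[i + j], d0[i] + di0[j])`, is a \((\max,+)\) convolution truncated at \(K\), and the other transitions have the same shape. A minimal sketch follows; the name `maxplus_merge` is hypothetical, and it uses `float("-inf")` where the sample uses `-INF = -10**18`. The double loop over pairs with \(i+j \le K\) is what makes each merge \(O(K^2)\).

```python
NEG = float("-inf")


def maxplus_merge(p, q, k):
    """(max, +) convolution truncated to indices 0..k.

    r[t] = max over i + j = t of p[i] + q[j]; runs in O(k^2).
    """
    r = [NEG] * (k + 1)
    for i, pi in enumerate(p):
        if pi == NEG:
            continue  # unreachable state: nothing to combine
        for j, qj in enumerate(q):
            if i + j > k:
                break  # indices only grow from here on
            r[i + j] = max(r[i + j], pi + qj)
    return r
```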

The transitions of the merge are too involved to list here, so please refer to the sample code. In the implementation below, we define another array \(d_2[i][k]\) for the condition “one of the paths has vertex \(i\) as an interior (non-endpoint) vertex.” After evaluating \(d_0,d_1,d_2\), we absorb \(d_1\) and \(d_2\) into \(d_0\), since \(d_0\) carries no constraint on vertex \(i\).
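For reference, reading the transitions off the sample code gives the summary below. Here \(v\) is the current vertex (`to`), \(u\) is the child being merged (`x`), \(e_0, e_1\) are the arrays returned for \(u\) (`di0`, `di1`), primes denote the arrays after merging \(u\), and all indices stay within \(0,\dots,K\). The second line either extends the path that ends at \(u\) up to \(v\) (so \(v\) becomes its endpoint and the path count is unchanged) or keeps a path that already ends at \(v\). The fourth line joins the path ending at \(v\) with the path ending at \(u\) through the edge \((v,u)\), turning two paths into one. The sample code also tries \(d_1[i]+e_1[j]\) and \(d_2[i]+e_1[j]\) at index \(i+j\); these never win, because \(e_0 \ge e_1\) elementwise after the final absorption step.

```latex
% Merging child u into v, for all i, j:
\begin{aligned}
d_0'[i+j]   &\gets \max\bigl(d_0'[i+j],\ d_0[i] + e_0[j]\bigr) \\
d_1'[i+j]   &\gets \max\bigl(d_1'[i+j],\ d_0[i] + e_1[j] + a_v,\ d_1[i] + e_0[j]\bigr) \\
d_2'[i+j]   &\gets \max\bigl(d_2'[i+j],\ d_2[i] + e_0[j]\bigr) \\
d_2'[i+j-1] &\gets \max\bigl(d_2'[i+j-1],\ d_1[i] + e_1[j]\bigr)
\end{aligned}
% After all children of v: start a one-vertex path at v (i >= 1), then absorb into d_0
\begin{aligned}
d_1[i] &\gets \max\bigl(d_1[i],\ d_0[i-1] + a_v\bigr) \\
d_0[i] &\gets \max\bigl(d_0[i],\ d_1[i],\ d_2[i]\bigr)
\end{aligned}
```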

Sample code (Python)

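try:
    import pypyjit  # available on PyPy only
    pypyjit.set_param('max_unroll_recursion=-1')  # let PyPy JIT-compile deep recursion
except ImportError:
    pass  # on CPython, skip the PyPy-specific tuning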
import sys
sys.setrecursionlimit(2 * 10**7)
input = sys.stdin.readline
n, k = map(int, input().split())
a = list(map(int, input().split()))
g = [[] for _ in range(n)]
for _ in range(n - 1):
    u, v = map(int, input().split())
    g[u - 1].append(v - 1)
    g[v - 1].append(u - 1)
INF = 10**18

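# f(to, fr) processes the subtree of `to` (whose parent is `fr`) and returns (d0, d1):
#   d0[i]: best sum using exactly i vertex-disjoint paths in the subtree
#   d1[i]: same, but `to` is an endpoint of one of the paths
# d2[i] is local: `to` is an interior vertex of one of the paths
def f(to, fr):
    d0 = [-INF] * (k + 1)
    d1 = [-INF] * (k + 1)
    d2 = [-INF] * (k + 1)
    d0[0] = 0
    for x in g[to]:
        if x == fr:
            continue
        new_d0 = [-INF] * (k + 1)
        new_d1 = [-INF] * (k + 1)
        new_d2 = [-INF] * (k + 1)
        di0, di1 = f(x, to)
        # Merge child x into the current state in O(k^2).
        for i in range(k + 1):
            for j in range(k + 1):
                if i + j <= k:
                    # `to` not used yet; x's subtree arbitrary
                    new_d0[i + j] = max(new_d0[i + j], d0[i] + di0[j])
                    # extend the path ending at x up to `to`, which becomes its endpoint
                    new_d1[i + j] = max(new_d1[i + j], d0[i] + di1[j] + a[to])
                    # `to` is already an endpoint / interior vertex; x's subtree arbitrary
                    new_d1[i + j] = max(new_d1[i + j], d1[i] + di0[j])
                    new_d2[i + j] = max(new_d2[i + j], d2[i] + di0[j])
                    # dominated by the two lines above (di0 >= di1), but harmless
                    new_d1[i + j] = max(new_d1[i + j], d1[i] + di1[j])
                    new_d2[i + j] = max(new_d2[i + j], d2[i] + di1[j])
                if 0 <= i + j - 1 <= k:
                    # join the path ending at `to` with the path ending at x: one path fewer
                    new_d2[i + j - 1] = max(new_d2[i + j - 1], d1[i] + di1[j])
        d0, d1, d2 = new_d0, new_d1, new_d2
    # start a new one-vertex path at `to`
    for i in range(1, k + 1):
        d1[i] = max(d1[i], d0[i - 1] + a[to])
    # absorb d1 and d2 into d0, which has no constraint on `to`
    for i in range(k + 1):
        d0[i] = max(d0[i], d1[i], d2[i])
    return d0, d1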

d0, d1 = f(0, -1)
print(max(d0))

With appropriate pruning, the complexity can be reduced to \(O(NK)\). For more details, please refer to this article (in Japanese).
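One standard form of pruning (we assume, but cannot confirm, that it is what the article describes) caps every loop by subtree size: a part with \(s\) vertices holds at most \(s\) vertex-disjoint paths, so its arrays need only \(\min(K, s)+1\) entries, and the usual size-bounded merge argument gives \(O(NK)\) in total. The sketch below applies this to the same DP. It processes vertices in reverse BFS order instead of recursing, which also avoids recursion-depth limits on CPython, and it drops the two dominated transitions noted in the sample code. `solve_pruned` is a hypothetical name; vertices and edges are 0-indexed.

```python
NEG = -10**18  # "minus infinity"; assumes |sum of a| is far below 10**18


def solve_pruned(n, k, a, edges):
    """Same DP as the sample code, with array lengths capped by subtree size."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    # BFS from root 0; reversed BFS order visits children before parents.
    parent = [-1] * n
    order = [0]
    for v in order:
        for w in adj[v]:
            if w != parent[v]:
                parent[w] = v
                order.append(w)
    size = [1] * n
    res = [None] * n  # res[v] = (d0, d1) for the subtree of v
    for v in reversed(order):
        s = 1  # vertices merged so far (v itself)
        m = min(k, s)  # largest path count worth storing
        d0 = [0] + [NEG] * m
        d1 = [NEG] * (m + 1)
        d2 = [NEG] * (m + 1)
        for c in adj[v]:
            if c == parent[v]:
                continue
            e0, e1 = res[c]
            res[c] = None  # the child's arrays are no longer needed
            s += size[c]
            m = min(k, s)
            n0, n1, n2 = [NEG] * (m + 1), [NEG] * (m + 1), [NEG] * (m + 1)
            for i in range(len(d0)):
                for j in range(len(e0)):
                    t = i + j
                    # the join below lands on t - 1, so t may exceed m by one
                    if t - 1 > m:
                        break
                    if t <= m:
                        n0[t] = max(n0[t], d0[i] + e0[j])
                        # extend c's path up to v, or keep a path already ending at v
                        n1[t] = max(n1[t], d0[i] + e1[j] + a[v], d1[i] + e0[j])
                        n2[t] = max(n2[t], d2[i] + e0[j])
                    if t >= 1:
                        # join v's path with c's path through edge (v, c)
                        n2[t - 1] = max(n2[t - 1], d1[i] + e1[j])
            d0, d1, d2 = n0, n1, n2
        size[v] = s
        for i in range(1, len(d0)):  # start a one-vertex path at v
            d1[i] = max(d1[i], d0[i - 1] + a[v])
        for i in range(len(d0)):  # absorb d1 and d2 into d0
            d0[i] = max(d0[i], d1[i], d2[i])
        res[v] = (d0, d1)
    return max(res[0][0])
```

For example, on the star with center 0, leaves 1, 2, 3 and all integers equal to 1, `solve_pruned(4, 2, [1, 1, 1, 1], [(0, 1), (0, 2), (0, 3)])` returns 4: one path leaf-center-leaf plus the remaining leaf.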
