Official

F - Paint Tree 2 Editorial by sounansya


与えられる木 \(T\) を頂点 \(1\) を根とする根付き木とします。

以下のような木 DP で答えを求めることを考えます。

  • \(d_c[i][k]\) を以下のように定める。
    • 部分木 \(i\) から \(k\) 本の点素パス(どの \(2\) つのパスも同じ頂点を共有しない)を取り、いずれかのパスに含まれる頂点に書かれた整数の総和の最大値を \(d_c[i][k]\) とする。
    • ただし、 \(c=0\) なら追加の制約はなく、 \(c=1\) ならいずれかのパスが端点として頂点 \(i\) を持つ必要がある。

\(\displaystyle \max_k d_0[1][k]\) が求める答えとなります。

各部分木のマージは \(O(K^2)\) 時間で行うことができるので、全体として \(O(NK^2)\) 時間で答えを求めることができます。

マージの式は複雑になるので、詳しくは実装例を参考にしてください。下の実装では \(d_2[i][k]\) を上の \(1\) つ目の定義に加え「いずれかのパスが端点以外の頂点として頂点 \(i\) を持つ」という条件を加えた \(3\) つの長さ \(m+1\) の配列 \(d_0,d_1,d_2\) を持ち、最後に \(d_2\)\(d_0\) に吸収させる遷移を書いています。

実装例(Python3)

import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')
import sys
sys.setrecursionlimit(2 * 10**7)
input = sys.stdin.readline
n, k = map(int, input().split())
a = list(map(int, input().split()))
g = [[] for _ in range(n)]
for _ in range(n - 1):
    u, v = map(int, input().split())
    g[u - 1].append(v - 1)
    g[v - 1].append(u - 1)
INF = 10**18

def f(to, fr):
    d0 = [-INF] * (k + 1)
    d1 = [-INF] * (k + 1)
    d2 = [-INF] * (k + 1)
    d0[0] = 0
    for x in g[to]:
        if x == fr:
            continue
        new_d0 = [-INF] * (k + 1)
        new_d1 = [-INF] * (k + 1)
        new_d2 = [-INF] * (k + 1)
        di0, di1 = f(x, to)
        for i in range(k + 1):
            for j in range(k + 1):
                if i + j <= k:
                    new_d0[i + j] = max(new_d0[i + j], d0[i] + di0[j])
                    new_d1[i + j] = max(new_d1[i + j], d0[i] + di1[j] + a[to])
                    new_d1[i + j] = max(new_d1[i + j], d1[i] + di0[j])
                    new_d2[i + j] = max(new_d2[i + j], d2[i] + di0[j])
                    new_d1[i + j] = max(new_d1[i + j], d1[i] + di1[j])
                    new_d2[i + j] = max(new_d2[i + j], d2[i] + di1[j])
                if 0 <= i + j - 1 <= k:
                    new_d2[i + j - 1] = max(new_d2[i + j - 1], d1[i] + di1[j])
        d0, d1, d2 = new_d0, new_d1, new_d2
    for i in range(1, k + 1):
        d1[i] = max(d1[i], d0[i - 1] + a[to])
    for i in range(k + 1):
        d0[i] = max(d0[i], d1[i], d2[i])
    return d0, d1

d0, d1 = f(0, -1)
print(max(d0))

また、適切な枝刈りを行うことで計算量を \(O(NK)\) に削減することもできます。詳しくは こちらの記事 を参考にしてください。

posted:
last update: