C - Forest Editorial by evima (Official)


Hints: https://atcoder.jp/contests/arc211/editorial/14696


Trees at either end of the row can never be removed, because there is no tree-free cell beyond them to act as an endpoint. Hereafter we therefore assume there are no trees at either end.

First, we may assume that each operation removes trees from exactly one interval. An operation that removes trees from two or more intervals at once can be replaced by removing them one at a time, which strictly increases the total reward.
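A tiny hypothetical instance makes the exchange argument concrete. The rules used here are an assumption inferred from this editorial, since the problem statement is not reproduced: an operation picks tree-free cells l < r with at least one tree strictly between them, cuts every tree strictly between them, and earns R[l] + R[r], with all R positive. Under those rules, an operation (l, r) that spans a tree-free cell m lying between two tree intervals can be replaced by (l, m) followed by (l, r): the final row is the same, and the extra operation earns R[l] + R[m] > 0.

```python
# Assumed rules (inferred from the editorial, not quoted from the problem):
# pick tree-free cells l < r with a tree strictly between them, cut every tree
# strictly between them, and earn R[l] + R[r].
# Hypothetical row ".#.#.": tree-free cells 0, 2, 4; trees at cells 1 and 3.
R = [4, 1, 2, 1, 3]  # hypothetical positive rewards

# One operation (0, 4) cuts both tree intervals at once.
at_once = R[0] + R[4]

# Split it: (0, 2) cuts the tree at 1, then (0, 4) cuts the tree at 3.
# The final row is identical, but the extra operation earns R[0] + R[2] > 0.
one_at_a_time = (R[0] + R[2]) + (R[0] + R[4])

print(at_once, one_at_a_time)  # 7 13
```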

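Let \(n\) be the number of intervals of trees in the initial state, and let \((r_1,r_2,\dots,r_{n+1})\) be the maxima of \(R\) over the tree-free intervals, from left to right. Each of the \(n\) operations contributes two values of \(R\), so \(2n\) values of \(R\) (counting duplicates) make up the total reward, and \(n+1\) of them are bounded by the elements of \(r\) (each tree-free interval contributes at most its own maximum the first time it is used). No value exceeds the maximum of \(R\), so the total reward is at most \(r_1+r_2+\dots+r_{n+1}+(n-1)\max R\), with equality exactly when all the remaining \(n-1\) values equal the maximum. Is this bound achievable?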
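A minimal sketch that evaluates this upper bound, assuming `S` uses `#` for trees and `.` for tree-free cells, has no tree at either end (as assumed above), and `R` is a list of the same length. The function name `upper_bound` is hypothetical. When there are no trees, the formula gives \(r_1 - \max R = 0\), which matches the fact that no operation is possible.

```python
from itertools import groupby


def upper_bound(S, R):
    """Return r_1 + ... + r_{n+1} + (n - 1) * max(R).

    Assumes S has no tree ('#') at either end, so max(R) is the relevant maximum.
    """
    r = []   # maximum of R on each tree-free interval, left to right
    n = 0    # number of tree intervals
    pos = 0  # start index of the current run
    for ch, grp in groupby(S):
        length = sum(1 for _ in grp)
        if ch == "#":
            n += 1
        else:
            r.append(max(R[pos:pos + length]))
        pos += length
    return sum(r) + (n - 1) * max(R)


print(upper_bound(".#.#.", [1, 5, 1, 1, 1]))  # r = [1, 1, 1], n = 2 -> 3 + 5 = 8
```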

Suppose the first operation involves the maximum of \(R\): one of its two tree-free intervals contains the maximum, or the trees it removes contain it (so that the maximum becomes usable once those trees are cut). Then each subsequent operation can pair the maximum with a fresh element of \(r\), extending the merged interval that holds the maximum by one tree interval at a time, and the upper bound is achieved. Conversely, a similar argument shows that if the first operation does not involve the maximum of \(R\), the upper bound cannot be achieved.
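To sanity-check the bound and the claim about the first operation on tiny inputs, one can brute-force every operation sequence. The rules encoded below are an assumption inferred from this editorial (the problem statement is not reproduced here): an operation picks tree-free cells l < r with at least one tree strictly between them, cuts every tree strictly between them, and earns R[l] + R[r]. The helper `brute_force` is hypothetical; it returns the best total and the number of first operations (l, r) that can begin an optimal sequence, which is what the implementation example appears to count. It is exponential and meant only for rows of a handful of cells.

```python
from functools import lru_cache


def brute_force(S, R):
    """Exhaustive search under the assumed rules.

    Returns (best total reward, number of first operations (l, r) from which
    that best total is still reachable).
    """
    n = len(S)

    def moves(state):
        # All (l, r) with both cells tree-free and a tree strictly between them.
        empty = [i for i in range(n) if not state[i]]
        for a, l in enumerate(empty):
            for r in empty[a + 1:]:
                if any(state[l + 1:r]):
                    # Every tree strictly between l and r is cut.
                    nxt = state[:l + 1] + (False,) * (r - l - 1) + state[r:]
                    yield l, r, nxt

    @lru_cache(maxsize=None)
    def best(state):
        # Best total obtainable from this state (0 when no move is possible).
        return max((R[l] + R[r] + best(nxt) for l, r, nxt in moves(state)),
                   default=0)

    start = tuple(c == "#" for c in S)
    total = best(start)
    first = sum(1 for l, r, nxt in moves(start)
                if R[l] + R[r] + best(nxt) == total)
    return total, first


print(brute_force(".#.#.", [1, 5, 1, 1, 1]))  # (8, 1)
```

In this instance the only good first operation is (0, 2), which cuts the tree holding the maximum 5; the best total 8 equals \(r_1+r_2+r_3+(n-1)\max R = 3+5\).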

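Therefore, we sum, over all locations where the first operation involves the maximum of \(R\), the number of ways to perform the first operation there. For a tree interval between tree-free intervals \(j\) and \(j+1\), that number is the count of positions attaining \(r_j\) times the count of positions attaining \(r_{j+1}\). After run-length compressing the row by the presence or absence of trees, this takes \(O(N)\) time. Thus, the problem is solved.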
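A sketch of this \(O(N)\) count using `itertools.groupby` for the run-length compression, assuming the row has already had its end trees removed (so it starts and ends with a tree-free run) and uses `#` for trees. The function name `count_optimal_first_ops` is hypothetical; the logic mirrors the implementation example but skips the padding.

```python
from itertools import groupby


def count_optimal_first_ops(S, R):
    """Sum of left_cnt * right_cnt over tree runs whose neighbourhood holds max(R).

    Assumes S has no tree at either end, so runs alternate
    tree-free, tree, ..., tree-free.
    """
    runs = []  # (maximum of R in the run, number of positions attaining it)
    pos = 0
    for _, grp in groupby(S):
        length = sum(1 for _ in grp)
        vals = R[pos:pos + length]
        m = max(vals)
        runs.append((m, vals.count(m)))
        pos += length

    M = max(R)
    ans = 0
    # Tree runs sit at odd indices, flanked by tree-free runs at i - 1 and i + 1.
    for i in range(1, len(runs) - 1, 2):
        left_max, left_cnt = runs[i - 1]
        tree_max, _ = runs[i]
        right_max, right_cnt = runs[i + 1]
        if max(left_max, tree_max, right_max) == M:
            ans += left_cnt * right_cnt
    return ans


print(count_optimal_first_ops("..#..", [1, 1, 3, 1, 1]))  # 2 * 2 = 4
```

Each slice, `max`, and `count` touches every cell once overall, so the whole function is linear in the row length.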


Some implementation tips follow. Padding the row with a tree at each end reduces case analysis.

Note that the maximum to compare against is not the maximum of the raw \(R\), but the maximum of \(R\) after the tree runs at both ends have been removed.
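A tiny hypothetical instance showing why this matters; the row and rewards are made up for illustration, and the slicing assumes the row contains at least one tree-free cell. The interior tree at cell 2 still counts, because it becomes usable once cut.

```python
S = "#.#."        # hypothetical row: the leading tree can never be cut
R = [9, 1, 5, 2]  # hypothetical rewards; 9 sits on that unusable tree

lo = len(S) - len(S.lstrip("#"))  # length of the leading tree run
hi = len(S.rstrip("#"))           # index just past the last tree-free cell

raw_max = max(R)              # 9: wrong, cell 0 never takes part in an operation
relevant_max = max(R[lo:hi])  # 5: the maximum the argument actually uses
print(raw_max, relevant_max)  # 9 5
```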

The answer may exceed the 32-bit integer range, so use a 64-bit type (e.g. long long) in languages such as C++, where int is usually 32 bits.

Implementation Example (Python (Codon 0.19.3), 39ms)

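N = int(input()) + 2
# Pad with a tree (reward 0) at each end so every tree-free run has trees on both sides.
S = "#" + input() + "#"
R = [0] + list(map(int, input().split())) + [0]

now = "#"  # kind of the current run ("#" or ".")
mx = 0     # maximum of R in the current run
cnt = 0    # number of positions in the current run attaining mx
rle = []   # finished runs as (mx, cnt)

for i in range(N):
    if S[i] == now:
        if mx == R[i]:
            cnt += 1
        elif mx < R[i]:
            mx = R[i]
            cnt = 1
    else:
        # Run boundary: store the finished run and start a new one at i.
        rle.append((mx, cnt))
        now = S[i]
        mx = R[i]
        cnt = 1

# Drop the leading tree run; the trailing tree run is never appended.
# rle now alternates tree-free, tree, ..., tree-free (even indices are tree-free).
rle = rle[1:]
mx = max(rle)[0]  # maximum of R after removing the tree runs at both ends
ans = 0

# For each tree run rle[i + 1]: if the maximum lies in it or in a neighbour,
# count the first operations between the two neighbouring tree-free runs.
for i in range(0, len(rle) - 1, 2):
    if max(rle[i][0], rle[i + 1][0], rle[i + 2][0]) == mx:
        ans += rle[i][1] * rle[i + 2][1]

print(ans)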
