H - Max × Sum Editorial by en_translator
For simplicity, we assume that \(A_i \leq A_{i+1}\) holds for all \(i\) \((1 \leq i \leq N-1)\). (Otherwise, sort the pairs \((A_i, B_i)\) in ascending order of \(A_i\); this does not change the answer.)
Let \(r\) be the largest index in \(S\); since \(S\) has \(K\) elements, \(K \leq r\). Then the remaining \(K-1\) elements of \(S\) must be chosen from \(\lbrace 1, 2, \dots, r-1 \rbrace\), and \(\max_{i \in S} A_i = A_r\) no matter how they are chosen. Thus the values \(A_i\) no longer matter, and we only need to minimize \(\sum_{i \in S} B_i\). The expression therefore reduces to:
\[A_r \times \left(B_r + (\text{the sum of the } K-1 \text{ smallest elements among } B_1, B_2, \dots, B_{r-1})\right).\]
Taking the minimum of this value over all \(r\), we see that the problem can be solved once we obtain the following for every \(r = K, K+1, \dots, N\):
- the sum of the smallest \((K-1)\) elements among \(B_1, B_2, \dots, B_{r-1}\).
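To sanity-check the reduction on small inputs, here is a sketch (not part of the original editorial) comparing an exhaustive search over all \(K\)-element subsets with a direct scan over \(r\) that sorts each prefix. The names brute_force and scan_over_r are hypothetical, indices are 0-based, and all \(A_i\) and \(B_i\) are assumed positive. The scan costs \(\mathrm{O}(N^2 \log N)\), so it is only meant for testing.

```python
from itertools import combinations


def brute_force(A, B, K):
    # Try every K-element index set S and evaluate max(A_i) * sum(B_i) directly.
    return min(
        max(A[i] for i in S) * sum(B[i] for i in S)
        for S in combinations(range(len(A)), K)
    )


def scan_over_r(A, B, K):
    # Sort pairs by A so that A is non-decreasing (ties broken by B, which is harmless).
    C = sorted(zip(A, B))
    best = None
    for r in range(K - 1, len(C)):  # 0-based r; C[:r] are the r pairs before C[r]
        a_r, b_r = C[r]
        prefix = sorted(b for _, b in C[:r])
        val = a_r * (b_r + sum(prefix[:K - 1]))  # A_r * (B_r + K-1 smallest earlier B)
        if best is None or val < best:
            best = val
    return best


print(brute_force([3, 1, 2], [1, 5, 2], 2))  # → 9
print(scan_over_r([3, 1, 2], [1, 5, 2], 2))  # → 9
```

In the example, the optimum takes the pairs \((3, 1)\) and \((2, 2)\), giving \(3 \times (1 + 2) = 9\).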
In fact, all of these values can be computed with a priority queue. For each \(n\), define a multiset \(Q_n\) as follows:
- the multiset consisting of the smallest \(\min(K-1, n)\) elements among \(B_1, B_2, \dots, B_n\).
Then \(Q_1, Q_2, \dots, Q_N\) can be computed incrementally as follows, starting from the empty multiset \(Q_0\). (Here, \(\lbrace \lbrace \rbrace \rbrace\) denotes a multiset.)
- \(Q_0 = \lbrace \lbrace \rbrace \rbrace\).
- \(Q_n = Q_{n-1} + \lbrace\lbrace B_n \rbrace\rbrace\) if \(1 \leq n \leq K-1\).
- \(Q_n\) equals \(Q_{n-1} + \lbrace \lbrace B_n \rbrace \rbrace\) with one occurrence of its largest value removed, if \(K \leq n \leq N\).
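The update rule above can be sketched with Python's heapq, which is a min-heap, so values are stored negated to emulate a max-heap. The generator name prefix_smallest_multisets is hypothetical; it yields each \(Q_n\) as a sorted list purely so the result can be compared with the definition of \(Q_n\).

```python
import heapq


def prefix_smallest_multisets(B, K):
    """Yield (n, sorted Q_n) for n = 1..N, where Q_n holds the
    min(K-1, n) smallest values among B_1..B_n."""
    heap = []  # negated values: the largest kept value sits at heap[0]
    for n, b in enumerate(B, start=1):
        heapq.heappush(heap, -b)  # Q_{n-1} + {{B_n}}
        if len(heap) > K - 1:
            heapq.heappop(heap)  # remove one occurrence of the largest value
        yield n, sorted(-x for x in heap)


print(list(prefix_smallest_multisets([5, 2, 7, 1], 3)))
# → [(1, [5]), (2, [2, 5]), (3, [2, 5]), (4, [1, 2])]
```

Because the heap never holds more than \(K-1\) values, each push or pop costs \(\mathrm{O}(\log K)\).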
Maintaining \(Q_n\) requires only two operations, “push an element” and “pop the maximum value”, so \(Q_n\) can be managed efficiently with a priority queue (a max-heap).
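To find the answer, we need the sum of the elements of each of \(Q_{K-1}, Q_K, \dots, Q_{N-1}\) (for a given \(r\), the required value is the sum of \(Q_{r-1}\)). These sums are easy to obtain by maintaining a running total alongside the priority queue: add each pushed element and subtract each popped one.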
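Combining the heap with a running total gives the following function-style sketch, convenient for testing on small cases. The names prefix_sums_of_Q and solve are hypothetical, indices are 0-based, \(Q_0\) is taken to be empty (which matters only when \(K = 1\)), and all \(A_i\) and \(B_i\) are assumed positive.

```python
import heapq


def prefix_sums_of_Q(B, K):
    """Return [sum(Q_{K-1}), sum(Q_K), ..., sum(Q_{N-1})]."""
    heap = []  # negated values, so heap[0] is minus the largest kept value
    total = 0  # running sum of the values currently kept in heap
    sums = []
    for n, b in enumerate(B):
        # Here heap holds Q_n: the min(K-1, n) smallest values of B[0..n-1].
        if n >= K - 1:
            sums.append(total)
        heapq.heappush(heap, -b)
        total += b
        if len(heap) > K - 1:
            total += heapq.heappop(heap)  # popped value is negated, so this subtracts it
    return sums


def solve(A, B, K):
    C = sorted(zip(A, B))  # ascending A
    sums = prefix_sums_of_Q([b for _, b in C], K)
    # C[K-1:] are the candidates for the maximum; sums[j] is the sum of Q for C[K-1+j].
    return min(a * (b + s) for (a, b), s in zip(C[K - 1:], sums))


print(solve([3, 1, 2], [1, 5, 2], 2))  # → 9
```

Sorting dominates the cost; the heap part runs in \(\mathrm{O}(N \log K)\).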
Implementing the procedure above solves the problem in \(\mathrm{O}(N \log N)\) time per test case (dominated by the sort), which is fast enough.
- Sample code (Python)
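import heapq

T = int(input())
for t in range(T):
    N, K = map(int, input().split())
    A = list(map(int, input().split()))
    B = list(map(int, input().split()))
    # Sort the pairs (A_i, B_i) in ascending order of A_i.
    C = sorted([(a, b) for a, b in zip(A, B)])
    ans = 10**18
    Q = []    # max-heap (values stored negated) of the K-1 smallest B seen so far
    bsum = 0  # sum of the values currently in Q
    for i in range(N):
        a, b = C[i]
        # Q now holds Q_i, so C[i] can serve as the element with the largest A.
        if len(Q) == K - 1:
            ans = min(ans, a * (bsum + b))
        # Push b, then pop the largest value if Q exceeds K-1 elements.
        heapq.heappush(Q, -b)
        bsum += b
        if len(Q) > K - 1:
            bsum -= -heapq.heappop(Q)
    print(ans)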