I - Maximum Composition Editorial
by
MtSaka
数列 \(p\) に使われる \(K\) 個の整数を定めたとき、最大値を取るような \(p\) にはどのような性質があるかを考えます。この性質を考える時に、以下の事実が重要になります。
\(f_i(f_j(x))\) と\(f_j(f_i(x))\) の大小は \(x\) によらない。また、この大小は\(\frac{A_i-1}{B_i}\) と\(\frac{A_j-1}{B_j}\) の大小による。\(\frac{A_i-1}{B_i}>\frac{A_j-1}{B_j}\) ならば \(f_i(f_j(x))>f_j(f_i(x))\) である。
この性質より、\(p\) に用いられている \(K\) 個の整数を定めたとき、\(\frac{A_{p_1}-1}{B_{p_1}} \geq \frac{A_{p_2}-1}{B_{p_2}} \geq \ldots \geq \frac{A_{p_K}-1}{B_{p_K}}\) を満たす \(p\)が \(f_{p_1}(f_{p_2}(\ldots f_{p_K}(1) \ldots))\) の最大値を取ります。
簡単のため、\(\frac{A_1-1}{B_1} \geq \frac{A_2-1}{B_2} \geq \ldots \frac{A_N-1}{B_N}\) を満たすように \((A,B)\) をソートしたとします。
この時以下の問題に言い換えることができます。
\(1\) 以上 \(N\) 以下の整数からなり、\(p_1 <p_2< \ldots <p_N\) を満たす \(p\) を取ったとき、\(f_{p_1}(f_{p_2}(\ldots f_{p_K}(1) \ldots))\) としてありえる最大値を求めよ。
この問題は以下のような dp で求めることができます。
\(dp_{i,j}=( p_K,p_{K-1},\ldots p_{K-i} \)まで定めて、\(p_{K-i} \geq j\) を満たすとき、\(f_{p_{K-i}}(\ldots f_{p_K}(1)\ldots)\) としてありえる最大値)
具体的に、各遷移は \(dp_{i,j}=\text{max}(dp_{i,j+1},A_j \times dp_{i-1,j+1}+B_j)\) となり、求めるべき答えは \(dp_{K,1}\) です。
状態数は \(O(NK)\) で遷移は \(O(1)\) であるため、全体で時間計算量 \(O(NK)\) でこの問題を解くことができます。
実装例(C++):
#include<bits/stdc++.h>
using namespace std;
int main() {
int n, k;
cin >> n >> k;
vector<int> a(n), b(n);
for (int i = 0; i < n; ++i) {
cin >> a[i] >> b[i];
}
vector<int> ord(n);
for (int i = 0; i < n; ++i)
ord[i] = i;
sort(ord.begin(), ord.end(), [&](int i, int j) { return b[i] * (a[j] - 1) > b[j] * (a[i] - 1); });
vector<long long> dp(k + 1, -1e9);
dp[0] = 1;
for (auto i : ord) {
vector<long long> ndp = dp;
for (int j = 0; j < k; ++j) if(dp[j] != -1e9) {
ndp[j + 1] = max(ndp[j + 1], dp[j] * a[i] + b[i]);
}
dp = move(ndp);
}
cout << dp[k] << "\n";
}
実装例(Python):
n, k = map(int, input().split())
a = []
b = []
for _ in range(n):
ai, bi = map(int, input().split())
a.append(ai)
b.append(bi)
ord_indices = list(range(n))
ord_indices.sort(key=lambda i: (a[i] - 1) / b[i])
dp = [int(-1e9)] * (k + 1)
dp[0] = 1
for i in ord_indices:
ndp = dp[:]
for j in range(k):
if dp[j] > int(-1e9):
ndp[j + 1] = max(ndp[j + 1], dp[j] * a[i] + b[i])
dp = ndp
print(dp[k])
posted:
last update: