Official

I - Maximum Composition Editorial by MtSaka


数列 \(p\) に使われる \(K\) 個の整数を定めたとき、最大値を取るような \(p\) にはどのような性質があるかを考えます。この性質を考える時に、以下の事実が重要になります。

\(f_i(f_j(x))\)\(f_j(f_i(x))\) の大小は \(x\) によらない。また、この大小は\(\frac{A_i-1}{B_i}\)\(\frac{A_j-1}{B_j}\) の大小による。\(\frac{A_i-1}{B_i}>\frac{A_j-1}{B_j}\) ならば \(f_i(f_j(x))>f_j(f_i(x))\) である。

この性質より、\(p\) に用いられている \(K\) 個の整数を定めたとき、\(\frac{A_{p_1}-1}{B_{p_1}} \geq \frac{A_{p_2}-1}{B_{p_2}} \geq \ldots \geq \frac{A_{p_K}-1}{B_{p_K}}\) を満たす \(p\)\(f_{p_1}(f_{p_2}(\ldots f_{p_K}(1) \ldots))\) の最大値を取ります。

簡単のため、\(\frac{A_1-1}{B_1} \geq \frac{A_2-1}{B_2} \geq \ldots \frac{A_N-1}{B_N}\) を満たすように \((A,B)\) をソートしたとします。

この時以下の問題に言い換えることができます。

\(1\) 以上 \(N\) 以下の整数からなり、\(p_1 <p_2< \ldots <p_N\) を満たす \(p\) を取ったとき、\(f_{p_1}(f_{p_2}(\ldots f_{p_K}(1) \ldots))\) としてありえる最大値を求めよ。

この問題は以下のような dp で求めることができます。

\(dp_{i,j}=( p_K,p_{K-1},\ldots p_{K-i} \)まで定めて、\(p_{K-i} \geq j\) を満たすとき、\(f_{p_{K-i}}(\ldots f_{p_K}(1)\ldots)\) としてありえる最大値)

具体的に、各遷移は \(dp_{i,j}=\text{max}(dp_{i,j+1},A_j \times dp_{i-1,j+1}+B_j)\) となり、求めるべき答えは \(dp_{K,1}\) です。

状態数は \(O(NK)\) で遷移は \(O(1)\) であるため、全体で時間計算量 \(O(NK)\) でこの問題を解くことができます。

実装例(C++):

#include<bits/stdc++.h>
using namespace std;

int main() {
    int n, k;
    cin >> n >> k;
    vector<int> a(n), b(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i] >> b[i];
    }
    vector<int> ord(n);
    for (int i = 0; i < n; ++i)
        ord[i] = i;
    sort(ord.begin(), ord.end(), [&](int i, int j) { return b[i] * (a[j] - 1) > b[j] * (a[i] - 1); });
    vector<long long> dp(k + 1, -1e9);
    dp[0] = 1;
    for (auto i : ord) {
        vector<long long> ndp = dp;
        for (int j = 0; j < k; ++j) if(dp[j] != -1e9) {
            ndp[j + 1] = max(ndp[j + 1], dp[j] * a[i] + b[i]);
        }
        dp = move(ndp);
    }
    cout << dp[k] << "\n";
}

実装例(Python):

n, k = map(int, input().split())
a = []
b = []
for _ in range(n):
    ai, bi = map(int, input().split())
    a.append(ai)
    b.append(bi)

ord_indices = list(range(n))
ord_indices.sort(key=lambda i: (a[i] - 1) / b[i])

dp = [int(-1e9)] * (k + 1)
dp[0] = 1

for i in ord_indices:
    ndp = dp[:]
    for j in range(k):
        if dp[j] > int(-1e9):
            ndp[j + 1] = max(ndp[j + 1], dp[j] * a[i] + b[i])
    dp = ndp

print(dp[k])

posted:
last update: