E - Maximum Glutton Editorial by en_translator
Notice that the total sweetness and the total saltiness of all the dishes he eats, except the last one, do not exceed \(X\) and \(Y\), respectively (*).
Let \(x\) be the maximum size of a set \(S\) of dishes whose total sweetness does not exceed \(X\) and whose total saltiness does not exceed \(Y\). Then the answer to the original problem is \(\min(x+1, N)\). (Proof: this is obvious if \(x=N\). If \(x<N\), he can eat at least \((x+1)\) dishes by eating the dishes in \(S\) first and then any other dish. Conversely, if he could eat \((x+2)\) dishes, then by (*) the first \((x+1)\) of them would form a set of size \((x+1)\) satisfying both limits, contradicting the maximality of \(x\). Thus, he can eat at most \((x+1)\) dishes.)
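To sanity-check this reduction on tiny inputs, one can compare a brute force over all eating orders with the subset-based formula \(\min(x+1, N)\). The sketch below assumes \(N\) is at most about 8 (it enumerates all \(N!\) orders and \(2^N\) subsets) and assumes the eating rule implied by (*): he eats a dish, and stops afterwards if either running total exceeds its limit. The helpers `maxEatenByOrders` and `answerBySubsets` are hypothetical names, not part of the original editorial.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Brute force (tiny N only): try every eating order and simulate the assumed rule:
// eat the next dish, then stop if either running total exceeds its limit.
int maxEatenByOrders(int X, int Y, const std::vector<int>& a, const std::vector<int>& b) {
    int n = static_cast<int>(a.size());
    std::vector<int> order(n);
    std::iota(order.begin(), order.end(), 0);
    int best = 0;
    do {
        long long sweet = 0, salt = 0;
        int eaten = 0;
        for (int idx : order) {
            sweet += a[idx];
            salt += b[idx];
            ++eaten;                               // this dish is eaten...
            if (sweet > X || salt > Y) break;      // ...and then he stops
        }
        best = std::max(best, eaten);
    } while (std::next_permutation(order.begin(), order.end()));
    return best;
}

// The reduction: x = size of the largest subset within both limits; answer = min(x + 1, N).
int answerBySubsets(int X, int Y, const std::vector<int>& a, const std::vector<int>& b) {
    int n = static_cast<int>(a.size());
    int x = 0;
    for (int mask = 0; mask < (1 << n); ++mask) {
        long long sweet = 0, salt = 0;
        int cnt = 0;
        for (int i = 0; i < n; ++i) {
            if ((mask >> i) & 1) {
                sweet += a[i];
                salt += b[i];
                ++cnt;
            }
        }
        if (sweet <= X && salt <= Y) x = std::max(x, cnt);
    }
    return std::min(x + 1, n);
}
```

For example, with \(X=8\), \(Y=4\) and dishes \((1,5), (3,2), (4,1), (5,3)\), no three dishes fit within both limits but \((3,2),(4,1)\) does, so both functions return 3.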
Now let us find the maximum number of dishes that can be chosen so that the total sweetness and saltiness do not exceed \(X\) and \(Y\), respectively. One of the simplest approaches is the following DP (Dynamic Programming), but its complexity is \(O(NXY)\), which is too heavy:
- \(dp_{i,j,k}=(\)the maximum number of dishes that can be chosen from dishes \(1,2,\dots,i\) so that the total sweetness is exactly \(j\) and the total saltiness is exactly \(k\)).
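Although too slow for the full constraints, this direct DP is easy to write and can serve as a cross-check for small \(X\) and \(Y\). The sketch below rolls away the index \(i\) and stores `-1` for unreachable states; the function name `answerNaive` is hypothetical, and it assumes every sweetness and saltiness value is positive.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Direct O(N X Y) DP with the dish index rolled away:
// best[j][k] = max number of dishes with total sweetness exactly j and
// total saltiness exactly k (-1 = unreachable). Only practical for small X, Y.
int answerNaive(int X, int Y, const std::vector<int>& a, const std::vector<int>& b) {
    int n = static_cast<int>(a.size());
    std::vector<std::vector<int>> best(X + 1, std::vector<int>(Y + 1, -1));
    best[0][0] = 0;
    for (int i = 0; i < n; ++i) {
        // Iterate both sums downward so each dish is used at most once (0/1 knapsack).
        for (int j = X; j >= a[i]; --j) {
            for (int k = Y; k >= b[i]; --k) {
                int prev = best[j - a[i]][k - b[i]];
                if (prev >= 0) best[j][k] = std::max(best[j][k], prev + 1);
            }
        }
    }
    int x = 0;  // size of the largest subset within both limits
    for (const auto& row : best)
        for (int v : row) x = std::max(x, v);
    return std::min(x + 1, n);
}
```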
Instead, noticing that \(N\) is much smaller than \(X\) and \(Y\) in this problem, we take a typical approach: swapping the key and the value of the DP. Specifically, we move the number of chosen dishes into the key and the total saltiness into the value, defining:
- \(dp'_{i,j,k}=(\)the minimum total saltiness when choosing exactly \(k\) dishes from dishes \(1,2,\dots,i\) so that the total sweetness is exactly \(j\)).
This way, the DP table can be filled in \(O(N^2X)\) time in total. All that is left is to find the maximum \(k\) such that there exists \(j\) with \(dp'_{N,j,k} \leq Y\); this \(k\) is the \(x\) defined above, so the answer is \(\min(k+1, N)\).
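For reference, the recurrence behind this DP can be written out as follows, writing \(A_i\) and \(B_i\) for the sweetness and saltiness of dish \(i\) (notation introduced here; they correspond to `a[i-1]` and `b[i-1]` in the 0-indexed sample code). There are \(O(N^2X)\) states, each with an \(O(1)\) transition.

```latex
\begin{aligned}
dp'_{0,0,0} &= 0, \qquad dp'_{0,j,k} = \infty \ \text{otherwise},\\
dp'_{i,j,k} &= \min\bigl(dp'_{i-1,j,k},\ dp'_{i-1,\,j-A_i,\,k-1} + B_i\bigr)
  \quad (\text{second term only when } j \ge A_i,\ k \ge 1),\\
x &= \max\{\, k \mid \exists\, j \le X,\ dp'_{N,j,k} \le Y \,\},
  \qquad \text{answer} = \min(x+1, N).
\end{aligned}
```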
Sample code (C++):
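```cpp
#include <bits/stdc++.h>
using namespace std;

void chmin(int &a, int b) { a = min(a, b); }

int main() {
    int n, x, y;
    cin >> n >> x >> y;
    vector<int> a(n), b(n);  // a[i]: sweetness, b[i]: saltiness of dish i
    for (int i = 0; i < n; i++) {
        cin >> a[i] >> b[i];
    }
    // dp[i][j][k]: minimum total saltiness when choosing exactly j dishes from
    // the first i dishes with total sweetness exactly k (1e9 = unreachable).
    // Note: compared with the text, the roles of j and k are swapped here.
    // Nested vector without template arguments requires C++17 (CTAD).
    vector dp(n + 1, vector(n + 1, vector<int>(x + 1, 1e9)));
    dp[0][0][0] = 0;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j <= i; j++) {
            for (int k = 0; k <= x; k++) {
                // Skip dish i.
                chmin(dp[i + 1][j][k], dp[i][j][k]);
                // Take dish i, as long as the sweetness stays within x.
                if (k + a[i] <= x) {
                    chmin(dp[i + 1][j + 1][k + a[i]], dp[i][j][k] + b[i]);
                }
            }
        }
    }
    // Largest count i with some sweetness j and saltiness <= y; answer is min(i + 1, n).
    for (int i = n; i >= 0; i--) {
        for (int j = 0; j <= x; j++) {
            if (dp[n][i][j] <= y) {
                cout << min(i + 1, n) << endl;
                return 0;
            }
        }
    }
}
```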
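The sample code keeps all \((N+1)^2(X+1)\) entries. Since each layer \(i+1\) depends only on layer \(i\), the dish dimension can be rolled away, as in a 0/1 knapsack, reducing memory to \((N+1)(X+1)\) entries while keeping \(O(N^2X)\) time. The following is a sketch of that variant, not the editorial's code; the function name `solveRolling` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Same DP with the dish dimension rolled away:
// dp[c][s] = minimum total saltiness over choices of exactly c dishes
// (among those processed so far) with total sweetness exactly s.
int solveRolling(int X, int Y, const std::vector<int>& a, const std::vector<int>& b) {
    const int INF = 1'000'000'000;
    int n = static_cast<int>(a.size());
    std::vector<std::vector<int>> dp(n + 1, std::vector<int>(X + 1, INF));
    dp[0][0] = 0;
    for (int i = 0; i < n; ++i) {
        // Descend in c: row c is read before it is overwritten in this pass,
        // so dish i is added at most once.
        for (int c = i; c >= 0; --c) {
            for (int s = X - a[i]; s >= 0; --s) {
                if (dp[c][s] < INF) {
                    dp[c + 1][s + a[i]] = std::min(dp[c + 1][s + a[i]], dp[c][s] + b[i]);
                }
            }
        }
    }
    // Largest count c with some sweetness s whose minimum saltiness fits within Y.
    for (int c = n; c >= 0; --c) {
        for (int s = 0; s <= X; ++s) {
            if (dp[c][s] <= Y) return std::min(c + 1, n);
        }
    }
    return 0;  // not reached for Y >= 0, since dp[0][0] = 0
}
```

To use it in place of the sample's DP, read `n`, `x`, `y`, `a`, `b` as the sample does and print `solveRolling(x, y, a, b)`.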