
## B - At Most 3 (Judge ver.) Editorial by en_translator

First, we can rephrase the problem statement as the following procedure. If we can carry out this procedure fast enough, the problem is solved.

• Prepare an array of flags flag that records, for every integer $$n$$ between $$1$$ and $$W$$ (inclusive), whether “$$n$$ is a good integer.” Initially, every element of flag is false.
• Inspect every non-empty set of weights of size at most $$3$$. For each set, compute the sum $$w$$ of the masses of its weights. If $$w$$ does not exceed $$W$$, set flag[w] to true.
• Finally, the answer is the number of flags whose value is true.
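
The first and last steps of this procedure can be sketched in isolation, as below. This is only an illustrative sketch: the helper name `countMarked` and the idea of passing the candidate sums as a ready-made list are assumptions made here, not part of the problem statement. It shows how the flag array makes duplicate sums count only once and how sums above $$W$$ are ignored.

```cpp
#include <vector>

// Hypothetical helper (not from the editorial): given candidate sums,
// return how many distinct integers in [1, W] appear among them.
// Assumes W >= 0.
int countMarked(const std::vector<long long>& sums, int W) {
    std::vector<char> flag(W + 1, 0);  // flag[n] != 0 means "n is a good integer"
    for (long long w : sums) {
        if (1 <= w && w <= W) flag[w] = 1;  // sums outside [1, W] are ignored
    }
    int count = 0;
    for (int n = 1; n <= W; n++) count += flag[n];  // count the true flags
    return count;
}
```

For example, with $$W = 10$$ and candidate sums 1, 3, 4, 4, 12, the duplicate 4 is marked once and 12 is skipped, so the function returns 3.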

The hardest part of the procedure above is to “inspect every set of weights of size at most $$3$$.” If you take the wrong approach to this search, your program's running time will blow up, leading to TLE (Time Limit Exceeded).
The key is to rephrase “of size at most $$3$$” as “of size $$1$$, $$2$$, or $$3$$.” The sets of size $$x$$ can be enumerated with $$x$$ nested for-loops, as in the following example, in a total of $$\mathrm{O}(N^x)$$ time. The problem therefore reduces to handling the cases of size $$1$$, $$2$$, and $$3$$ separately, each with its own for-loops.

// Here is an example of the for-loops for x = 2
for (int i = 0; i < N; i++) {
    for (int j = i + 1; j < N; j++) {
        // Process pair (i, j) of different weights
    }
}
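
The same nesting pattern works for any fixed $$x$$. As a minimal sketch of that generalization (the function `forEachSubsetSum` and its parameters are hypothetical names introduced here, and recursion is swapped in for the hand-written nested loops), each recursion level plays the role of one for-loop, and starting the next level at `i + 1` ensures every set of indices is visited exactly once:

```cpp
#include <functional>
#include <vector>

// Hypothetical helper (not from the editorial): calls visit(sum) once for
// every set of exactly x distinct indices of A, where sum is the total of
// the chosen elements. Each recursion level corresponds to one nested loop.
void forEachSubsetSum(const std::vector<int>& A, int x,
                      const std::function<void(long long)>& visit,
                      int start = 0, long long sum = 0) {
    if (x == 0) {  // x elements have been chosen: report their sum
        visit(sum);
        return;
    }
    for (int i = start; i < static_cast<int>(A.size()); i++) {
        // Choose A[i], then pick the remaining x - 1 elements from later indices.
        forEachSubsetSum(A, x - 1, visit, i + 1, sum + A[i]);
    }
}
```

For `A = {1, 2, 4}` and `x = 2`, the callback receives 3, 5, and 6 in that order. For a problem with a small fixed bound such as $$3$$, writing the loops out by hand is usually simpler.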


Initializing and counting the flag array takes $$\mathrm{O}(W)$$ time and the for-loops take $$\mathrm{O}(N+N^2+N^3) = \mathrm{O}(N^3)$$ time, for a total of $$\mathrm{O}(W + N^3)$$, which is fast enough.
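
To get a feel for the constant hidden in $$\mathrm{O}(N^3)$$, the loops visit exactly $$\binom{N}{1}+\binom{N}{2}+\binom{N}{3}$$ sets. As an illustration only, assume $$N = 300$$ (a value picked here for the arithmetic, not taken from the text above):

$$
\binom{300}{1}+\binom{300}{2}+\binom{300}{3} = 300 + 44850 + 4455100 = 4500250
$$

So under that assumption the loops perform about $$4.5 \times 10^6$$ iterations, roughly one sixth of $$N^3 = 2.7 \times 10^7$$.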

Sample code in C++ follows.

#include <iostream>
#include <vector>
using namespace std;

int main() {
    int N, W;
    cin >> N >> W;
    vector<int> A(N);
    for (auto& x : A) cin >> x;

    // flag[n] == 1 means "n is a good integer"; all entries start at 0
    vector<int> flag(W + 1);

    // Sets of size 1
    for (int i = 0; i < N; i++) {
        int s = A[i];
        if (s <= W) flag[s] = 1;
    }

    // Sets of size 2
    for (int i = 0; i < N; i++) {
        for (int j = i + 1; j < N; j++) {
            int s = A[i] + A[j];
            if (s <= W) flag[s] = 1;
        }
    }

    // Sets of size 3
    for (int i = 0; i < N; i++) {
        for (int j = i + 1; j < N; j++) {
            for (int k = j + 1; k < N; k++) {
                int s = A[i] + A[j] + A[k];
                if (s <= W) flag[s] = 1;
            }
        }
    }

    // Count the good integers between 1 and W
    int ans = 0;
    for (int i = 1; i <= W; i++) ans += flag[i];
    cout << ans << endl;
}

