Official

B - At Most 3 (Judge ver.) Editorial by en_translator


First of all, we can rephrase the Problem Statement as the following programming-style procedure. If this procedure can be executed fast enough, the problem is solved.

  • Prepare an array flag that records, for every integer \(n\) between \(1\) and \(W\) (inclusive), whether \(n\) is a good integer. Initially, every element of flag is false.
  • Inspect every set of at most \(3\) weights. For each set, compute the total mass \(w\) of its weights. If \(w\) does not exceed \(W\), set flag[w] to true.
  • Finally, the answer is the number of elements of flag that are true.

The hardest part of the procedure above is to “inspect every set of at most \(3\) weights.” With the wrong approach to this search, such as enumerating all \(2^N\) subsets and discarding the large ones, the running time grows exponentially in \(N\), leading to TLE (Time Limit Exceeded).
The key is to rephrase “of size at most \(3\)” as “of size \(1\), \(2\), or \(3\).” The sets of size \(x\) can be enumerated with \(x\) nested for-loops in a total of \(\mathrm{O}(N^x)\) time, as in the following implementation, so the problem boils down to handling sets of size \(1\), \(2\), and \(3\) separately, each with its own for-loops.

// Here is an example of the for-loops for x = 2
  for(int i = 0; i < N; i++) {
    for(int j = i + 1; j < N; j++) {
      // Process pair (i, j) of different weights
    }
  }
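The same pattern extends to \(x = 3\) by adding a third loop whose index k starts at j + 1. As a sketch (the helpers countSets and checkFormula are hypothetical, not part of the original solution), the code below counts how many sets the loops for \(x = 1, 2, 3\) visit. Keeping the indices strictly increasing (i < j < k) ensures each set is visited exactly once, and checkFormula compares the count with the number of ways to choose 1, 2, or 3 of the \(N\) weights.

```cpp
#include <cassert>

// Hypothetical helper: counts how many sets of size 1, 2 and 3 the nested
// for-loops visit for N weights. Strictly increasing indices (i < j < k)
// mean every set is visited exactly once.
long long countSets(int N) {
  long long visited = 0;
  // Size 1: one loop, O(N).
  for (int i = 0; i < N; i++) visited++;
  // Size 2: two nested loops, O(N^2).
  for (int i = 0; i < N; i++)
    for (int j = i + 1; j < N; j++) visited++;
  // Size 3: three nested loops, O(N^3).
  for (int i = 0; i < N; i++)
    for (int j = i + 1; j < N; j++)
      for (int k = j + 1; k < N; k++) visited++;
  return visited;
}

// Hypothetical check: the loop count must equal the closed form
// C(N,1) + C(N,2) + C(N,3).
void checkFormula(int N) {
  long long n = N;
  long long expected = n + n * (n - 1) / 2 + n * (n - 1) * (n - 2) / 6;
  assert(countSets(N) == expected);
}
```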

The flag array takes \(\mathrm{O}(W)\) time and the for-loops take \(\mathrm{O}(N+N^2+N^3) = \mathrm{O}(N^3)\) time, for a total of \(\mathrm{O}(W + N^3)\), which is fast enough.
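More precisely, since each set is visited exactly once, the number of sets examined is a sum of binomial coefficients, each bounded by a power of \(N\):

\[
\binom{N}{1} + \binom{N}{2} + \binom{N}{3} = N + \frac{N(N-1)}{2} + \frac{N(N-1)(N-2)}{6} \le N + N^2 + N^3 = \mathrm{O}(N^3).
\]

For example, \(N = 10\) gives \(10 + 45 + 120 = 175\) sets.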

Sample code in C++ follows.

#include <iostream>
#include <vector>
using namespace std;

int main() {
  int N, W;
  cin >> N >> W;
  vector<int> A(N);
  for (auto& x : A) cin >> x;
  
  // flag[s] == 1 means s is a good integer; all elements start at 0 (index 0 is unused)
  vector<int> flag(W + 1);

  // Sets of size 1
  for (int i = 0; i < N; i++) {
    int s = A[i];
    if (s <= W) flag[s] = 1;
  }

  // Sets of size 2 (i < j)
  for (int i = 0; i < N; i++) {
    for (int j = i + 1; j < N; j++) {
      int s = A[i] + A[j];
      if (s <= W) flag[s] = 1;
    }
  }
  
  // Sets of size 3 (i < j < k)
  for (int i = 0; i < N; i++) {
    for (int j = i + 1; j < N; j++) {
      for (int k = j + 1; k < N; k++) {
        int s = A[i] + A[j] + A[k];
        if (s <= W) flag[s] = 1;
      }
    }
  }
  
  // Count the good integers between 1 and W
  int ans = 0;
  for (int i = 1; i <= W; i++) ans += flag[i];
  cout << ans << endl;
}
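One way to gain confidence in the sample code is to compare it against a naive solution on small random inputs. The sketch below is illustrative only: the names solveLoops, solveBrute, countFlags, and crossCheck are hypothetical, and masses are assumed to be positive integers. solveLoops repeats the loops of the sample code as a function, while solveBrute enumerates all \(2^N\) subsets by bitmask and keeps those with at most \(3\) elements. The latter is exactly the exponential search that leads to TLE for large \(N\), so it is used here only for \(N \le 8\).

```cpp
#include <cassert>
#include <random>
#include <vector>

// Number of flags set among indices 1..W.
int countFlags(const std::vector<int>& flag, int W) {
  int ans = 0;
  for (int w = 1; w <= W; w++) ans += flag[w];
  return ans;
}

// The sample code's loops wrapped in a function (hypothetical name).
int solveLoops(const std::vector<int>& A, int W) {
  int N = static_cast<int>(A.size());
  std::vector<int> flag(W + 1, 0);
  for (int i = 0; i < N; i++)
    if (A[i] <= W) flag[A[i]] = 1;
  for (int i = 0; i < N; i++)
    for (int j = i + 1; j < N; j++) {
      int s = A[i] + A[j];
      if (s <= W) flag[s] = 1;
    }
  for (int i = 0; i < N; i++)
    for (int j = i + 1; j < N; j++)
      for (int k = j + 1; k < N; k++) {
        int s = A[i] + A[j] + A[k];
        if (s <= W) flag[s] = 1;
      }
  return countFlags(flag, W);
}

// Naive reference (hypothetical name): enumerate every non-empty subset by
// bitmask and keep those with at most 3 elements. Exponential in N.
int solveBrute(const std::vector<int>& A, int W) {
  int N = static_cast<int>(A.size());
  std::vector<int> flag(W + 1, 0);
  for (int mask = 1; mask < (1 << N); mask++) {
    int size = 0;
    long long s = 0;
    for (int i = 0; i < N; i++)
      if ((mask >> i) & 1) {
        size++;
        s += A[i];
      }
    if (size <= 3 && s <= W) flag[s] = 1;
  }
  return countFlags(flag, W);
}

// Compares both methods on random small inputs; the ranges are arbitrary.
void crossCheck(unsigned seed, int trials) {
  std::mt19937 rng(seed);
  std::uniform_int_distribution<int> sizeDist(1, 8), capDist(1, 30), massDist(1, 15);
  for (int t = 0; t < trials; t++) {
    std::vector<int> A(sizeDist(rng));
    for (auto& x : A) x = massDist(rng);
    int W = capDist(rng);
    assert(solveLoops(A, W) == solveBrute(A, W));
  }
}
```

For instance, with masses \(2, 2, 2, 2\) and \(W = 10\), both functions count only \(2\), \(4\), and \(6\): the total \(8\) needs four weights and is therefore not a good integer.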
