
C - 2026 Editorial by en_translator


First, let us consider a solution without worrying about computational complexity. All we need is to decide, for each integer up to \(N\), whether it is good, which leads to the following algorithm:

  • Prepare an array \(a\) that manages the answers. Initially, \(a\) is empty.
  • For each \(n=1,2,\dots, N\), do the following:
    • First, set \(c = 0\).
    • For each integer \(x\) with \(0 \lt x \leq \sqrt{n}\), do the following. (We impose the bound \(x \leq \sqrt{n}\) because, if \(x \gt \sqrt{n}\), there is no \(y\) such that \(x^2+y^2=n\).)
      • Find the non-negative real number \(y\) such that \(x^2 + y^2 = n\), that is, compute sqrt(n - x * x).
      • If \(y\) is a positive integer and \(x \lt y\), increment \(c\) by \(1\).
    • If \(c = 1\) at the end, then \(n\) is good, so push \(n\) to \(a\).
  • The final array \(a\) stores the answers.
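The brute-force algorithm above can be sketched in C++ as follows. This is only a sketch for small \(N\) (for instance, to cross-check a faster solution), and the names isqrt and good_integers_naive are hypothetical helpers introduced here. Unlike the plain sqrt(n - x * x) described above, it uses an integer square root that corrects the floating-point estimate; this is an assumption made so that rounding errors cannot affect the test "is \(y\) an integer?".

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Integer square root: the largest r with r * r <= v (requires v >= 0).
// Starts from the floating-point estimate and then corrects it, so
// rounding errors in std::sqrt cannot give a wrong result.
long long isqrt(long long v) {
  assert(v >= 0);
  long long r = static_cast<long long>(std::sqrt(static_cast<double>(v)));
  while (r * r > v) r--;
  while ((r + 1) * (r + 1) <= v) r++;
  return r;
}

// Brute force: for every n, count pairs (x, y) with 0 < x < y and
// x^2 + y^2 = n. Runs in O(N sqrt N) time, so it is for small N only.
std::vector<int> good_integers_naive(int N) {
  std::vector<int> a;
  for (int n = 1; n <= N; n++) {
    int c = 0;
    for (long long x = 1; x * x <= n; x++) {
      long long rest = n - x * x;  // must equal y^2
      long long y = isqrt(rest);
      if (y * y == rest && y > 0 && x < y) c++;
    }
    if (c == 1) a.push_back(n);  // exactly one representation: good
  }
  return a;
}
```

For \(N = 65\), this returns the 18 integers 5, 10, 13, 17, 20, 25, 26, 29, 34, 37, 40, 41, 45, 50, 52, 53, 58, 61; 65 itself is excluded because \(65 = 1^2 + 8^2 = 4^2 + 7^2\).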

However, this algorithm computes \(y\) \(\mathrm{O}(N \sqrt{N})\) times in total, so implementing it directly results in a TLE (Time Limit Exceeded) verdict.

Here, notice that every pair \((x, y)\) that contributes to some count satisfies \(x^2 + y^2 \leq N\). This suggests enumerating the pairs \((x, y)\) directly instead of fixing \(n\) first, which gives the following algorithm:

  • First, let \(c = (c_1, c_2, \dots, c_N)\) be an array filled with zeros.
  • For \(x = 1, 2, \dots\), while \(x^2 \leq N\), do the following:
    • For \(y = x+1, x+2, \dots\), while \(x^2 + y^2 \leq N\), do the following:
      • Increment \(c_{x^2 + y^2}\) by one.
  • Finally, the good integers are those \(n\) with \(c_n = 1\), so enumerate such \(n\).
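To illustrate the criterion \(c_n = 1\), here is a small hand-worked check of two values. The condition \(x \lt y\) is what rules out \(5^2 + 5^2\):

\[
\begin{aligned}
50 &= 1^2 + 7^2 = 5^2 + 5^2, && \text{only } (1, 7) \text{ has } x \lt y, \text{ so } c_{50} = 1 \text{ (good)},\\
65 &= 1^2 + 8^2 = 4^2 + 7^2, && \text{both pairs have } x \lt y, \text{ so } c_{65} = 2 \text{ (not good)}.
\end{aligned}
\]

Running the double loop with \(N = 65\) increments the counter 20 times, and 18 values of \(n\) end with \(c_n = 1\).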

What is the complexity of this algorithm? Every inspected pair satisfies \(x^2 + y^2 \leq N\), so \(x, y \leq \sqrt{N}\). Hence at most \(\sqrt{N} \times \sqrt{N} = N\) pairs \((x, y)\) are inspected, and the final scan over \(c\) also takes \(\mathrm{O}(N)\) time. Therefore, the problem is solved in \(\mathrm{O}(N)\) time, which is fast enough.
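The counting argument can be written out as a single chain of inequalities. The first step simply drops the conditions \(x \lt y\) and \(x^2 + y^2 \leq N\), keeping only \(x, y \leq \lfloor\sqrt{N}\rfloor\):

\[
\#\{(x, y) \mid 1 \leq x \lt y,\ x^2 + y^2 \leq N\}
\;\leq\; \#\{(x, y) \mid 1 \leq x, y \leq \lfloor\sqrt{N}\rfloor\}
\;=\; \lfloor\sqrt{N}\rfloor^2 \;\leq\; N.
\]

As a rough heuristic rather than a bound, the inspected pairs are lattice points in one eighth of the disk of radius \(\sqrt{N}\), so their number is about \(\pi N / 8 \approx 0.39N\). For example, \(N = 65\) gives 20 pairs.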

  • Sample code (C++)
#include <iostream>
#include <vector>
using namespace std;

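int main() {
  int N;
  cin >> N;
  // c[n] = number of pairs (x, y) with 0 < x < y and x^2 + y^2 = n
  vector<int> c(N + 1);
  for (int x = 1; x * x <= N; x++) {
    for (int y = x + 1; x * x + y * y <= N; y++) {
      c[x * x + y * y]++;
    }
  }
  // n is good exactly when it has one such representation
  vector<int> ans;
  for (int n = 1; n <= N; n++) {
    if (c[n] == 1) ans.push_back(n);
  }
  // Print the count, then the good integers separated by spaces
  cout << ans.size() << "\n";
  for (int i = 0; i < (int)ans.size(); i++) {
    cout << ans[i] << " \n"[i + 1 == (int)ans.size()];
  }
}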
