Official

C - Filling 3x3 array Editorial by en_translator


Some of you may have thought, “this puzzle/mathematics problem can be solved with a computer, can’t it?” Some of you may also have tried it, either succeeding in writing code that outputs the correct answer, or ending up with a program that never finishes and produces nothing.
This problem is another chance to solve such a puzzle by programming. This time, quite a few contestants may have gotten TLE (Time Limit Exceeded) on large cases. If you couldn’t solve this problem, we would like you to learn how to “speed up your program through appropriate observations.”

(In the explanation below, we let \(N = \max(h_1,h_2,h_3,w_1,w_2,w_3)\).)
The easiest solution to come up with is naive brute force. Since each square can only contain an integer between \(1\) and \(N\), we can try assigning every such integer to each square and check whether the resulting grid satisfies the conditions.
This can be achieved with a \(9\)-fold nested for-loop, but that is rather clumsy; let’s implement it with depth-first search (DFS) instead. Roughly speaking, it can be written as follows:

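int h[3], w[3], a[3][3];
long long ans = 0;
// Returns true if every row sum and every column sum matches the input.
bool ok() {
  for (int k = 0; k < 3; k++) {
    if (a[k][0] + a[k][1] + a[k][2] != h[k]) return false;
    if (a[0][k] + a[1][k] + a[2][k] != w[k]) return false;
  }
  return true;
}
void dfs(int ij) {
  int i = ij / 3, j = ij % 3;  // cell index -> (row, column)
  if (i == 3) {  // all nine cells are filled
    if (ok()) ans++;
    return;
  }
  for (int x = 1; x <= 30; x++) {
    a[i][j] = x;
    dfs(ij + 1);
  }
}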

This solution does yield the right answer, but it has a complexity of \(\mathrm{O}(N^9)\), leading to TLE. How can we prune the search to make it faster?

Let us name each square from \(a\) through \(i\) as follows:

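\[\begin{array}{|c|c|c|} \hline a & b & c \\ \hline d & e & f \\ \hline g & h & i \\ \hline \end{array}\]

(The rows sum to \(h_1, h_2, h_3\) from top to bottom, and the columns sum to \(w_1, w_2, w_3\) from left to right.)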

One property of this problem is that once two of the squares in a row or column are filled, the remaining one is automatically determined.
For example, if \(h_1 = 10, a = 3, b = 5\), then \(c\) is determined by:

\[c = h_1 - a - b = 10 - 3 - 5 = 2.\]

The original program iterated over all the possibilities for the whole \(3 \times 3\) grid, but using this fact, we can show that we only have to enumerate the top-left \(\bf 2 \times 2\) grid.

We will explain in detail. Suppose that we have determined the values of \(a,b,d\), and \(e\) during the search. Using the fact that the \(i\)-th row has a sum of \(h_i\) for \(i=1,2\), we can determine \(c\) and \(f\) by:

\[\begin{aligned} c &= h_1 - a - b \\ f &= h_2 - d - e \\ \end{aligned} \]

Moreover, now that we have obtained \(a\) through \(f\), the column sums determine \(g\), \(h\), and \(i\):

\[\begin{aligned} g &= w_1 - a - d \\ h &= w_2 - b - e \\ i &= w_3 - c - f \\ \end{aligned} \]

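The only conditions we have not used yet are \(g+h+i=h_3\) and the requirement that every square holds a positive integer, so it is sufficient to check that the sum of \(g\), \(h\), and \(i\) obtained above is equal to \(h_3\) and that the derived values \(c, f, g, h, i\) are all positive.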
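As a side observation, substituting the expressions above into \(g+h+i\) suggests that the remaining row condition does not actually depend on the enumerated values. The following derivation is a sketch using only the equations already stated:

\[\begin{aligned} g + h + i &= (w_1 - a - d) + (w_2 - b - e) + (w_3 - c - f) \\ &= (w_1 + w_2 + w_3) - (a + b + c) - (d + e + f) \\ &= (w_1 + w_2 + w_3) - h_1 - h_2 \end{aligned} \]

Hence \(g+h+i=h_3\) holds if and only if \(h_1+h_2+h_3 = w_1+w_2+w_3\). This is why an implementation may compare the two totals once before searching, and then only needs to check positivity inside the search.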

Since we have narrowed the brute force down to four squares, the time complexity drops to \(\mathrm{O}(N^4)\), which is fast enough under the Constraints. We can implement it with for-loops or DFS.
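To get a feel for the difference, here is a rough count of the grids visited by each approach, assuming the upper bound \(N = 30\) that the loops in this editorial use:

\[\begin{aligned} 30^9 &= 3^9 \times 10^9 = 19683 \times 10^9 \approx 1.97 \times 10^{13} \\ 30^4 &= 3^4 \times 10^4 = 810000 = 8.1 \times 10^{5} \end{aligned} \]

Under the common rule of thumb (an assumption, not a figure from this editorial) that a program performs very roughly \(10^8\) to \(10^9\) simple operations per second, the first count is far out of reach, while the second finishes almost instantly.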

  • Sample code using for-loop (C++)
#include <algorithm>
#include <iostream>
using namespace std;
int H[3], W[3], ans = 0;
int main() {
  for (int i = 0; i < 3; i++) cin >> H[i];
  for (int j = 0; j < 3; j++) cin >> W[j];
  for (int a = 1; a <= 30; a++) {
    for (int b = 1; b <= 30; b++) {
      for (int d = 1; d <= 30; d++) {
        for (int e = 1; e <= 30; e++) {
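          int c = H[0] - a - b;  // forced by row 1
          int f = H[1] - d - e;  // forced by row 2
          int g = W[0] - a - d;  // forced by column 1
          int h = W[1] - b - e;  // forced by column 2
          int i = W[2] - c - f;  // forced by column 3
          // Count the grid if all derived cells are positive and row 3 matches.
          if (min({c, f, g, h, i}) > 0 and g + h + i == H[2]) ans++;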
        }
      }
    }
  }
  cout << ans << "\n";
}
  • Sample code using DFS (C++)
#include <iostream>
using namespace std;
int h[3], w[3], a[3][3];
long long ans = 0;
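void dfs(int ij) {
  int i = ij / 3, j = ij % 3;  // cell index -> (row, column)
  if (i == 3) {  // every cell has been filled with a positive value
    ans++;
    return;
  }
  if (i == 2) {  // bottom row: the value is forced by the column sum
    int x = w[j] - a[0][j] - a[1][j];
    if (x <= 0) return;
    a[i][j] = x, dfs(ij + 1);
  } else if (j == 2) {  // right end of rows 1 and 2: forced by the row sum
    int x = h[i] - a[i][0] - a[i][1];
    if (x <= 0) return;
    a[i][j] = x, dfs(ij + 1);
  } else {  // top-left 2x2 block: try every candidate value
    for (int x = 1; x <= 30; x++) a[i][j] = x, dfs(ij + 1);
  }
}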
int main() {
  for (int i = 0; i < 3; i++) cin >> h[i];
  for (int j = 0; j < 3; j++) cin >> w[j];
  // If the row total and the column total differ, no grid can exist.
  if (h[0] + h[1] + h[2] != w[0] + w[1] + w[2]) {
    cout << 0 << "\n";
    return 0;
  }
  dfs(0);
  cout << ans << "\n";
}
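
For readers who want to check the idea on small inputs, the same \(2 \times 2\) enumeration can be wrapped in a function. The following is a minimal sketch, not the editorial's own code: the name `count_grids` is hypothetical, and the tightened loop bounds (stopping as soon as \(c\) or \(f\) would drop below \(1\)) are a small variation on the fixed `1..30` loops above.

```cpp
#include <array>
#include <cassert>

// Hypothetical helper (not part of the editorial's sample code): counts the
// 3x3 grids of positive integers whose row sums are h and column sums are w.
long long count_grids(const std::array<int, 3>& h, const std::array<int, 3>& w) {
  long long ans = 0;
  // Enumerate only the top-left 2x2 block (a, b, d, e). The loop bounds
  // guarantee that c and f, which are forced by the row sums, stay >= 1.
  for (int a = 1; a <= h[0] - 2; a++) {
    for (int b = 1; a + b <= h[0] - 1; b++) {
      int c = h[0] - a - b;  // forced by row 1
      for (int d = 1; d <= h[1] - 2; d++) {
        for (int e = 1; d + e <= h[1] - 1; e++) {
          int f = h[1] - d - e;   // forced by row 2
          int g = w[0] - a - d;   // forced by column 1
          int hh = w[1] - b - e;  // forced by column 2 (named hh to avoid clashing with h)
          int i = w[2] - c - f;   // forced by column 3
          // Count the grid if the bottom row is positive and sums to h[2].
          if (g > 0 && hh > 0 && i > 0 && g + hh + i == h[2]) ans++;
        }
      }
    }
  }
  return ans;
}
```

For instance, with every row and column sum equal to \(4\), each row and each column must contain exactly one \(2\) and two \(1\)s, so `count_grids({4, 4, 4}, {4, 4, 4})` should return \(3! = 6\). With every sum equal to \(3\), only the all-ones grid works, so the result should be \(1\).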
