Official

D - Cuboid Sum Query Editorial by en_translator


This problem is an exercise in so-called cumulative sums. We assume familiarity with cumulative sums of a sequence (one-dimensional cumulative sums).

One-dimensional cumulative sums are used to solve a problem like this:

Given a length-\(N\) sequence \(A\), process \(Q\) queries.

In the \(i\)-th query, given an integer pair \((L_i,R_i)\) satisfying \(1 \leq L_i \leq R_i \leq N\), find \(\sum_{k=L_i}^{R_i}A_k\).

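A solution to this problem is as follows. Define a length-\((N+1)\) sequence \(S=(S_0,S_1,\ldots,S_N)\) by \(S_0=0\) and \(S_i=\sum_{k=1}^{i}A_k\) for \(1 \leq i \leq N\). Then \(\sum_{k=L_i}^{R_i}A_k=S_{R_i}-S_{L_i-1}\), so after computing \(S\) in \(O(N)\) time via \(S_i=S_{i-1}+A_i\), each query can be answered in constant time.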

The sequence \(S\) in this solution is generally called the cumulative sums of \(A\).
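The following is a minimal Python sketch of the one-dimensional technique. It assumes the sequence is stored as a 0-indexed Python list `a`, so `a[k - 1]` is \(A_k\). The helper names `build_prefix` and `range_sum` are hypothetical and chosen only for illustration.

```python
from itertools import accumulate


def build_prefix(a):
    """Return [S_0, S_1, ..., S_N] where S_0 = 0 and S_i = A_1 + ... + A_i."""
    return [0] + list(accumulate(a))


def range_sum(s, l, r):
    """Return A_l + ... + A_r (1-indexed, inclusive) as S_r - S_{l-1}."""
    return s[r] - s[l - 1]


a = [3, 1, 4, 1, 5]        # A_1, ..., A_5
s = build_prefix(a)        # [0, 3, 4, 8, 9, 14]
print(range_sum(s, 2, 4))  # A_2 + A_3 + A_4 = 1 + 4 + 1 = 6
```

Building `s` takes \(O(N)\) time. After that, each query is a single subtraction instead of a loop over \(R_i-L_i+1\) elements.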

The original problem is an extension of this one-dimensional cumulative sum problem into three dimensions.

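Analogously to the definition above, define \(S_{i,j,k}=\sum_{x=1}^{i} \sum_{y=1}^{j}\sum_{z=1}^{k}A_{x,y,z}\) for \(0 \leq i,j,k \leq N\). (If \(i\), \(j\), or \(k\) is \(0\), the sum is empty, so \(S_{i,j,k}=0\).)

Then, by inclusion-exclusion, it holds that \(\sum_{x=Lx_i}^{Rx_i} \sum_{y=Ly_i}^{Ry_i} \sum_{z=Lz_i}^{Rz_i} A_{x,y,z}=S_{Rx_i,Ry_i,Rz_i}-S_{Lx_i-1,Ry_i,Rz_i}-S_{Rx_i,Ly_i-1,Rz_i}-S_{Rx_i,Ry_i,Lz_i-1}+S_{Lx_i-1,Ly_i-1,Rz_i}+S_{Lx_i-1,Ry_i,Lz_i-1}+S_{Rx_i,Ly_i-1,Lz_i-1}-S_{Lx_i-1,Ly_i-1,Lz_i-1}\). Each term corresponds to a corner of the box. On each axis it uses either the upper bound \(R\) or \(L-1\), and its sign is \(-1\) raised to the number of axes on which \(L-1\) is used.

Therefore, once \(S\) has been computed, each query can be answered in constant time.

In fact, \(S\) itself can be computed fast. Apply the formula above to the box consisting of the single cell \((i,j,k)\), where the lower and upper bounds on each axis are both \(i\), \(j\), \(k\) respectively. Rearranging gives \(S_{i,j,k}=S_{i-1,j,k}+S_{i,j-1,k}+S_{i,j,k-1}-S_{i-1,j-1,k}-S_{i-1,j,k-1}-S_{i,j-1,k-1}+S_{i-1,j-1,k-1}+A_{i,j,k}\) for \(1 \leq i,j,k \leq N\). Filling in \(S\) in increasing order of \(i\), \(j\), and \(k\) therefore takes \(O(N^3)\) time, and the overall time complexity is \(O(N^3+Q)\).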
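To check the eight signs, the sketch below builds \(S\) with the recurrence and compares query answers against a brute-force triple loop on a small random array. It writes the eight-term formula as a loop over the corners of the box, which is equivalent to spelling out the terms. The function names `build_prefix_3d`, `box_sum`, and `brute_box_sum` are hypothetical, and the array size and value range are arbitrary choices for the test.

```python
import random


def build_prefix_3d(a):
    """Return S with S[i][j][k] = sum of A_{x,y,z} over x <= i, y <= j, z <= k.

    `a` is an n x n x n nested list with a[x-1][y-1][z-1] = A_{x,y,z}.
    Entries of S with any index equal to 0 are empty sums and stay 0.
    """
    n = len(a)
    s = [[[0] * (n + 1) for _ in range(n + 1)] for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            for k in range(1, n + 1):
                s[i][j][k] = (s[i - 1][j][k] + s[i][j - 1][k] + s[i][j][k - 1]
                              - s[i - 1][j - 1][k] - s[i - 1][j][k - 1]
                              - s[i][j - 1][k - 1] + s[i - 1][j - 1][k - 1]
                              + a[i - 1][j - 1][k - 1])
    return s


def box_sum(s, lx, rx, ly, ry, lz, rz):
    """Sum of A over [lx, rx] x [ly, ry] x [lz, rz] (1-indexed, inclusive)."""
    total = 0
    # Each corner takes r or l - 1 on every axis; the sign flips once
    # for every axis on which l - 1 is used.
    for x, sx in ((rx, 1), (lx - 1, -1)):
        for y, sy in ((ry, 1), (ly - 1, -1)):
            for z, sz in ((rz, 1), (lz - 1, -1)):
                total += sx * sy * sz * s[x][y][z]
    return total


def brute_box_sum(a, lx, rx, ly, ry, lz, rz):
    """Direct triple loop, used only to cross-check box_sum."""
    return sum(a[x - 1][y - 1][z - 1]
               for x in range(lx, rx + 1)
               for y in range(ly, ry + 1)
               for z in range(lz, rz + 1))


random.seed(0)
n = 4
a = [[[random.randint(1, 1000) for _ in range(n)] for _ in range(n)]
     for _ in range(n)]
s = build_prefix_3d(a)
for _ in range(200):
    lx, rx = sorted(random.randint(1, n) for _ in range(2))
    ly, ry = sorted(random.randint(1, n) for _ in range(2))
    lz, rz = sorted(random.randint(1, n) for _ in range(2))
    assert box_sum(s, lx, rx, ly, ry, lz, rz) == brute_box_sum(a, lx, rx, ly, ry, lz, rz)
print("all checks passed")
```

The loop form makes the sign rule explicit. A contest solution would normally write out the eight terms directly, as the sample code does.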

For more details, please refer to the sample code.

Sample code (C++):

#include <bits/stdc++.h>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    // a[i][j][k] holds A_{i+1,j+1,k+1} (0-indexed storage).
    vector a(n, vector(n, vector(n, 0)));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            for (int k = 0; k < n; ++k) {
                cin >> a[i][j][k];
            }
        }
    }
    // sum[i][j][k] = S_{i,j,k}; entries with an index of 0 stay 0.
    vector sum(n + 1, vector(n + 1, vector(n + 1, 0LL)));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            for (int k = 0; k < n; ++k) {
                // Inclusion-exclusion recurrence for S_{i+1,j+1,k+1}.
                sum[i + 1][j + 1][k + 1] =
                    sum[i][j + 1][k + 1] + sum[i + 1][j][k + 1] +
                    sum[i + 1][j + 1][k] - sum[i][j][k + 1] - sum[i][j + 1][k] -
                    sum[i + 1][j][k] + sum[i][j][k] + a[i][j][k];
            }
        }
    }
    int q;
    cin >> q;
    for (int i = 0; i < q; ++i) {
        int lx, rx, ly, ry, lz, rz;
        cin >> lx >> rx >> ly >> ry >> lz >> rz;
        // Turn each lower bound L into L - 1, as required by the formula.
        lx--, ly--, lz--;
        // Eight-term inclusion-exclusion over the corners of the box.
        cout << sum[rx][ry][rz] - sum[lx][ry][rz] - sum[rx][ly][rz] -
                    sum[rx][ry][lz] + sum[lx][ly][rz] + sum[lx][ry][lz] +
                    sum[rx][ly][lz] - sum[lx][ly][lz]
             << "\n";
    }
}

Sample code (Python):

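import sys

readline = sys.stdin.readline  # faster than the built-in input() for large Q

n = int(readline())

# a[i][j][k] holds A_{i+1,j+1,k+1} (0-indexed storage);
# each input line lists A_{x,y,1}, ..., A_{x,y,N}.
a = [[[0] * n for _ in range(n)] for _ in range(n)]
for i in range(n):
    for j in range(n):
        a[i][j] = list(map(int, readline().split()))

# s[i][j][k] = S_{i,j,k}; entries with an index of 0 stay 0.
# (Named s rather than sum to avoid shadowing the built-in sum.)
s = [[[0] * (n + 1) for _ in range(n + 1)] for _ in range(n + 1)]

for i in range(n):
    for j in range(n):
        for k in range(n):
            s[i + 1][j + 1][k + 1] = (s[i][j + 1][k + 1]
                                      + s[i + 1][j][k + 1]
                                      + s[i + 1][j + 1][k]
                                      - s[i][j][k + 1]
                                      - s[i][j + 1][k]
                                      - s[i + 1][j][k]
                                      + s[i][j][k]
                                      + a[i][j][k])

q = int(readline())
out = []
for _ in range(q):
    lx, rx, ly, ry, lz, rz = map(int, readline().split())
    # Turn each lower bound L into L - 1, as required by the formula.
    lx -= 1
    ly -= 1
    lz -= 1

    # Eight-term inclusion-exclusion over the corners of the box.
    out.append(s[rx][ry][rz]
               - s[lx][ry][rz]
               - s[rx][ly][rz]
               - s[rx][ry][lz]
               + s[lx][ly][rz]
               + s[lx][ry][lz]
               + s[rx][ly][lz]
               - s[lx][ly][lz])
print("\n".join(map(str, out)))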
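The sample code builds \(S\) with the eight-term recurrence. An alternative construction, not the one used in this editorial, yields the same table. Take one-dimensional cumulative sums along the \(z\) axis, then the \(y\) axis, then the \(x\) axis. Each pass costs \(O(N^3)\) and uses only one addition per cell. The sketch below assumes the same 0-indexed input layout as the sample code, and the function name `build_prefix_3d_by_axes` is hypothetical.

```python
def build_prefix_3d_by_axes(a):
    """Build S (zero-padded at index 0) with three passes of 1-D prefix sums.

    `a` is an n x n x n nested list with a[x-1][y-1][z-1] = A_{x,y,z}.
    """
    n = len(a)
    s = [[[0] * (n + 1) for _ in range(n + 1)] for _ in range(n + 1)]
    # Copy A into the padded table.
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            for k in range(1, n + 1):
                s[i][j][k] = a[i - 1][j - 1][k - 1]
    # One pass per axis: z first, then y, then x. Within a pass the
    # predecessor cell has already been updated, so the sums accumulate.
    for di, dj, dk in ((0, 0, 1), (0, 1, 0), (1, 0, 0)):
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                for k in range(1, n + 1):
                    s[i][j][k] += s[i - di][j - dj][k - dk]
    return s


example = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
print(build_prefix_3d_by_axes(example)[2][2][2])  # 36, the sum of all eight values
```

The per-axis form extends to \(d\) dimensions with \(d\) passes. By contrast, the inclusion-exclusion recurrence needs \(2^d\) terms per cell.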
