F - Numbered Checker Editorial by en_translator (official)


There are many approaches to this problem, and they differ in how much implementation they require. Here we describe a simple one.
Let us look at the \(6\)-th query in Sample Input \(1\) and consider the grid row by row.
The grid looks as follows:

  1  0  3  0
  0  6  0  8
  9  0 11  0
  0 14  0 16
 17  0 19  0
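To check the reasoning below on small cases, a brute-force reference helps. This is a minimal sketch that assumes the rule read off the grid above and the parity test in the sample code: cell \((i, j)\) holds \((i-1)M + j\) when \(i + j\) is even and \(0\) otherwise. The names `cell_value` and `brute_sum` are hypothetical, and the sums are exact (no modulus), so use it only on small inputs.

```cpp
#include <cassert>

// Hypothetical helper: value of cell (i, j), 1-indexed, in a grid with M columns,
// assuming (i-1)*M + j when i + j is even and 0 otherwise.
long long cell_value(long long M, long long i, long long j) {
  return (i + j) % 2 == 0 ? (i - 1) * M + j : 0;
}

// Brute-force sum over rows a..b and columns c..d.
// Exact arithmetic with no modulus: intended for small inputs only.
long long brute_sum(long long M, long long a, long long b, long long c, long long d) {
  long long s = 0;
  for (long long i = a; i <= b; i++)
    for (long long j = c; j <= d; j++)
      s += cell_value(M, i, j);
  return s;
}
```

For example, `brute_sum(4, 1, 5, 1, 4)` adds up all twenty cells of the grid above and gives 104.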

First, we compute the sum of each row. Since the non-zero elements in each row form an arithmetic sequence with common difference \(2\), we can compute the sum using the formula for the sum of an arithmetic sequence.
Also, since the grid has a checker pattern, the leftmost non-zero element of the extracted row (if any) is either its \(1\)-st or \(2\)-nd element. (The same applies to the rightmost one.)
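The row-sum step can be sketched as follows: move each endpoint inward by at most one column to reach a non-zero cell, then apply the arithmetic-series formula. This sketch assumes cell \((x, j)\) holds \((x-1)M + j\) when \(x + j\) is even and \(0\) otherwise; `row_sum_exact` is a hypothetical name, and it uses exact integers without a modulus, so it is only meant for small values.

```cpp
#include <cassert>

// Sum of row x (1-indexed) restricted to columns l..r, in a grid with M columns,
// assuming cell (x, j) holds (x-1)*M + j when x + j is even and 0 otherwise.
// Exact integer arithmetic: intended for small values only (no modulus).
long long row_sum_exact(long long M, long long x, long long l, long long r) {
  if ((x + l) % 2 != 0) l++;   // leftmost non-zero is column l or l+1
  if ((x + r) % 2 != 0) r--;   // rightmost non-zero is column r or r-1
  if (l > r) return 0;         // no non-zero cell in this range
  long long first = (x - 1) * M + l;  // first term of the sequence
  long long last = (x - 1) * M + r;   // last term of the sequence
  long long count = (r - l) / 2 + 1;  // terms are two columns apart
  return (first + last) * count / 2;  // arithmetic series formula
}
```

For the second row of the grid above, `row_sum_exact(4, 2, 1, 4)` gives \(6 + 8 = 14\).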

Next, let us group the rows by the parity of their index. First, look at the odd-indexed rows:

  1  0  3  0
  9  0 11  0
 17  0 19  0

For the odd-indexed rows, the following holds:
since the checker pattern repeats every two rows, every odd-indexed row has the same number \(x\) of non-zero elements in the extracted columns, and the row sums form an arithmetic sequence with common difference \(2 \times x \times M\). (This also holds for \(x=0\).)
This is because, in each column that contains non-zero elements, the values form an arithmetic sequence with common difference \(2 \times M\): moving down two rows adds \(2M\) to each element.
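Summing such a group of rows is itself an arithmetic series. If the first row of the group has sum \(S\) and \(x\) non-zero elements, the total over \(k\) rows \(a, a+2, \dots, a+2(k-1)\) is \(kS + 2xM \cdot \frac{k(k-1)}{2} = kS + xMk(k-1)\). A minimal exact (non-modular) sketch; `same_parity_total` is a hypothetical name:

```cpp
#include <cassert>

// Given the column-restricted sum S of one row and its number x of non-zero
// cells, return the total over that row and the next k-1 rows of the same
// parity (rows a, a+2, ..., a+2(k-1)) in a grid with M columns.
// Row sums grow by 2*x*M per step, so the total is
//   k*S + 2*x*M*(0 + 1 + ... + (k-1)) = k*S + x*M*k*(k-1).
// Exact arithmetic: small values only.
long long same_parity_total(long long S, long long x, long long M, long long k) {
  return k * S + x * M * k * (k - 1);
}
```

For the odd-indexed rows above (row sums 4, 20, 36 with \(x = 2\), \(M = 4\)), `same_parity_total(4, 2, 4, 3)` returns 60.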

The same holds for the even-indexed rows:

  0  6  0  8
  0 14  0 16

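Therefore, it suffices to compute, for the \(A_i\)-th and \((A_i+1)\)-th rows restricted to the \(C_i\)-th through \(D_i\)-th columns, the row sum and the number of non-zero elements. The sum over the \(A_i\)-th through \(B_i\)-th rows then follows from two arithmetic sequences: one for the rows with the same parity as \(A_i\), and one for the others.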
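Putting these steps together, here is an exact (non-modular) sketch of a single query, usable only on small inputs where nothing overflows. It assumes cell \((i, j)\) holds \((i-1)M + j\) when \(i + j\) is even and \(0\) otherwise. The names `row_info` and `query_exact` are hypothetical; the sample code below performs the same steps modulo 998244353.

```cpp
#include <cassert>
#include <utility>

// {sum, number of non-zero cells} of row x over columns l..r (M columns),
// assuming cell (x, j) holds (x-1)*M + j when x + j is even, else 0.
std::pair<long long, long long> row_info(long long M, long long x, long long l, long long r) {
  if ((x + l) % 2 != 0) l++;  // snap to the nearest non-zero column
  if ((x + r) % 2 != 0) r--;
  if (l > r) return {0, 0};
  long long cnt = (r - l) / 2 + 1;
  long long first = (x - 1) * M + l;
  // first + (first + 2) + ... = cnt*first + 2*(0 + 1 + ... + (cnt-1))
  return {cnt * first + cnt * (cnt - 1), cnt};
}

// Sum over rows a..b and columns c..d, exact arithmetic (small inputs only).
long long query_exact(long long M, long long a, long long b, long long c, long long d) {
  long long rows = b - a + 1;
  long long k1 = (rows + 1) / 2;  // rows a, a+2, ...
  long long k2 = rows / 2;        // rows a+1, a+3, ...
  std::pair<long long, long long> r1 = row_info(M, a, c, d);
  std::pair<long long, long long> r2 = row_info(M, a + 1, c, d);
  // Each group is an arithmetic series with common difference 2*x*M.
  long long g1 = k1 * r1.first + r1.second * M * k1 * (k1 - 1);
  long long g2 = k2 * r2.first + r2.second * M * k2 * (k2 - 1);
  return g1 + g2;
}
```

On the grid above, `query_exact(4, 2, 4, 2, 3)` returns \(6 + 11 + 14 = 31\).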

One key point for the implementation is that we need to compute the sum of an arithmetic sequence many times, so it is worth extracting this operation into a function to reduce the amount of code.
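The `arithmetic_sum` function in the sample code divides by \(2\) modulo \(998244353\) by multiplying by the constant `inv2 = 499122177`. Since \(998244353\) is prime, this constant can be derived from Fermat's little theorem as \(2^{p-2} \bmod p\). A minimal sketch with a hypothetical helper `mod_pow` (not part of the sample code):

```cpp
#include <cassert>

// Fast modular exponentiation by repeated squaring: returns base^exp mod m.
// Intermediate products stay below m*m, which fits in long long for m < 2^31.
long long mod_pow(long long base, long long exp, long long m) {
  long long result = 1;
  base %= m;
  while (exp > 0) {
    if (exp & 1) result = result * base % m;
    base = base * base % m;
    exp >>= 1;
  }
  return result;
}
```

`mod_pow(2, 998244353 - 2, 998244353)` yields 499122177, and indeed \(2 \times 499122177 = 998244354 \equiv 1\). An alternative that avoids the inverse is to compute \(\mathrm{num}(\mathrm{num}-1)/2\) exactly as an integer before reducing modulo \(p\).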

Sample code (C++):

#include<bits/stdc++.h>
#define mod 998244353
#define inv2 499122177 // inverse of 2

using namespace std;

// A_1 = fir
// A_i = fir + (i-1) * d
// return A_1 + A_2 + ... + A_{num}
long long arithmetic_sum(long long fir,long long d,long long num){
  long long las=(fir+(num-1)*d)%mod;
  long long res=(fir+las)%mod;
  res*=num;res%=mod;
  res*=inv2;res%=mod;
  return res;
}

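using pl=pair<long long,long long>;

long long n,m;
// For row x restricted to columns l..r, return
//   first:  the sum of its non-zero elements (mod 998244353),
//   second: 2*M*(number of non-zero elements), the common difference between
//           this row's sum and the sum of row x+2 (mod 998244353).
pl row_data(long long x,long long l,long long r){
  // Move l and r inward to the nearest columns holding non-zero elements.
  if((x+l)%2){l++;}
  if((x+r)%2){r--;}
  if(l>r){return {0,0};}

  long long mi=((x-1)*m+l)%mod; // leftmost non-zero element
  long long num=1+(r-l)/2;      // number of non-zero elements
  long long sum=arithmetic_sum(mi,2,num);

  num*=2;num%=mod;
  num*=m;num%=mod;
  return {sum,num};
}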

int main(){
  int q;
  cin >> n >> m >> q;
  for(int i=0;i<q;i++){
    long long a,b,c,d;
    cin >> a >> b >> c >> d;

    // Data for the first row of each parity group.
    pl r1=row_data(a,c,d);
    pl r2=row_data(a+1,c,d);
    // c1 rows share the parity of a (a, a+2, ...); c2 rows have the other parity.
    long long c1=(b-a+1)/2 + (b-a+1)%2;
    long long c2=(b-a+1)/2;
    long long res=0;
    res+=arithmetic_sum(r1.first,r1.second,c1);res%=mod;
    res+=arithmetic_sum(r2.first,r2.second,c2);res%=mod;
    cout << res << "\n";
  }
  return 0;
}
