Official

F - Balanced Rectangles Editorial by physics0523


この問題には \(H \times W \le 300000\) という特別な制約がかかっています。
この制約から、 \(\min(H,W) \le \sqrt{300000} \approx 548\) ということが分かります。
\(H \le W\) となるように予め入力を変形しておくことで、 \(O(H^2W)\) といった時間計算量をかけても各パラメータに値を代入して額面 \(1.7 \times 10^8\) という計算量になり、時間制限に間に合いそうなことが分かります。

以降、 \(H \le W\) であるよう入力を変形したものとして考えます。
例えば \(90\) 度回転を用いてもよければ、 \(S'_{i,j} = S_{j,i}\) と変形しても構いません。どちらの変形でも答えは不変です。

まず、長方領域のうち上辺と下辺を固定します。この選び方は \(O(H^2)\) 通りあります。上辺と下辺を固定した時の問題を時間計算量 \(O(W)\) で解ければこの問題に正解できます。

最も上の行を \(u\) 、最も下の行を \(d\) と固定した場合の問題は次の通りです。

\(C_j\) を次の通りに定義する。

  • 上から \(u\) 行目から \(d\) 行目、左から \(1\) 行目から \(W\) 行目を取り出した領域を \(G\) とする。
  • \(C_j\) を ( \(G\) の左から \(j\) 列目に含まれる # の数 ) \(-\) ( \(G\) の左から \(j\) 列目に含まれる . の数 ) と定義する。

このとき、 \(C_l+C_{l+1}+\dots+C_{r}=0\) かつ \(l \le r\) となる整数の組 \((l,r)\) の数が求めたい値である。

この問題はまさに Zero-Sum Ranges です。しかし、この解説の通りに解くと \(O(W \log W)\) となり、全体で \(O(H^2 W \log W)\) 、値を代入して額面 \(1.5 \times 10^{9}\) という計算量となり、よほど定数倍に気を遣った実装でない限り実行時間に間に合いません (逆に、この程度の計算量であれば非常に気を遣った実装をすると正解できる場合もあります。)
では、どのように高速化すればよいでしょうか?

ここで、 \(C_l+C_{l+1}+\dots+C_{r}\) のとりうる値域に注目しましょう。 \(C_i\) の定義より、この値域は \(-HW\) 以上 \(HW\) 以下です。なので、総和が \(0\) となる連続部分列を数える際に mapsort といった \(\log\) がつくデータ構造の代わりに単なる配列を工夫して用いることで、この \(\log\) を排除することができます。
なお、毎回 \(C_i\) を愚直に計算すると計算量が増加するため、 \(u,d\) に関してループを回す際に工夫して \(C_i\) を更新する必要があります。

\(\log\) を排除すると、最初の望み通りに時間計算量 \(O(H^2W)\) でこの問題に正解できます。

実装例 (C++):

#include<bits/stdc++.h>

using namespace std;

vector<string> flip(vector<string> &s){
  int H=s.size(),W=s[0].size();
  vector<string> res(W);
  for(int i=0;i<W;i++){
    for(int j=0;j<H;j++){
      res[i].push_back(s[j][i]);
    }
  }
  return res;
}

int main(){
  int t;
  cin >> t;
  while(t>0){
    t--;
    int H,W;
    cin >> H >> W;
    vector<string> s(H);
    for(auto &nx : s){cin >> nx;}
    if(H>W){s=flip(s);}
    H=s.size();
    W=s[0].size();

    int ofs=H*W;
    vector<int> bk(2*H*W+1,0);
    long long res=0;
    for(int u=0;u<H;u++){
      vector<int> C(W,0);
      for(int d=u;d<H;d++){
        for(int i=0;i<W;i++){
          if(s[d][i]=='#'){C[i]++;}
          else{C[i]--;}
        }
        int h;
        h=0;
        bk[h+ofs]++;
        for(int i=0;i<W;i++){
          h+=C[i];
          res+=bk[h+ofs];
          bk[h+ofs]++;
        }

        // reset bk
        h=0;
        bk[h+ofs]=0;
        for(int i=0;i<W;i++){
          h+=C[i];
          bk[h+ofs]=0;
        }
      }
    }
    cout << res << "\n";
  }
  return 0;
}

posted:
last update: