F - XOR on Grid Path Editorial by cirno3153


\(\mathrm{expected} \ O(2^N)\) で解く方法を紹介します。

公式解説と同じように、半分全列挙で解きます。

ここで、対角線まで移動する方法を考えます。これは、合計で \(N-1\) 回だけ右か下に移動する方法になります。

すなわち、長さ \(N-1\) のbit列を考えると、その各bitの \(0/1\) を右/下に対応させることで一対一に対応します。

これで、単純なbit全探索で対角線への移動を実装できるようになりました。 これを普通にbit全探索すると \(O(N 2^N)\) ですが、深さ優先探索を用いると \(O(2^N)\) になります。

深さ優先探索を用いると計算量が落ちる理由

深さ優先探索における探索は、高さ \(N\) の完全二分木と対応付けられます。

各頂点に対して値を割り振ります。

根は \(0\) とします。 それ以外の頂点は、親の値を \(x\) として、左側の子なら \(2x\) 、右側の子なら \(2x+1\) とします。

この時、葉の値は長さ \(N-1\) のbit列と対応しており、この完全二分木は頂点数 \(2^N\) となります。

この木に対して深さ優先探索をすると、各辺に対して潜る時と結果を返すために登る時の2回だけ通ることから頂点数の二倍の計算量で抑えられます。

また、各頂点において行うべき操作は対応する座標の値をxorすることであり、これは \(O(1)\) で行えます。

従って、全体の計算量を \(O(2^N)\) で抑えられます。

最後に半分全列挙したものを突き合わせて答えを得る部分ですが、これは片方を走査した時、もう片方が存在するかはハッシュテーブルを用いると \(\mathrm{expected} \ O(1)\) で判定できます。

C++の実装例

#include<iostream>
#include<algorithm>
#pragma GCC target("avx2")
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")
using namespace std;
using uint=unsigned int;
using ull=unsigned long long;
struct hash_map_open_addressing {
  #define LG 16
  #define MASK 0xFFFFFFFF
  ull store[1 << LG];
  bool is_occupied[1 << LG] = {};

  constexpr inline void next_pos(uint &pos) const {
    ++ pos &= (1 << LG) - 1;
  }

  static constexpr ull kek = 11995408973635179863ull;
  constexpr inline uint hsh(const uint& key) const {
    return (key * kek) >> (64 - LG);
  }

  inline uint operator[](const uint& key) const {
    uint pos = hsh(key);
    for (; is_occupied[pos]; next_pos(pos)) {
      if ((store[pos] & MASK) == key) return store[pos] >> 32;
    }
    return 0;
  }

  inline void increment(const uint& key) {
    uint pos = hsh(key);
    for (; is_occupied[pos]; next_pos(pos)) {
      if ((store[pos] & MASK) == key) {
        store[pos] += 1ull << 32;
        return;
      }
    }
    store[pos] = 1ull << 32 | key;
    is_occupied[pos] = true;
  }
};

int main(){
  int n;
  cin >> n;
  #define MAX_N 20
  #define W 22
  int a[W * W];
  for(int y = 1;y <= n;++y) for(int x = 1;x <= n;++x) cin >> a[y * W + x];
  hash_map_open_addressing map[MAX_N];
  ull ans = 0;
  int diagonal1 = n - 2;
  int diagonal2 = n - 1;
  int stack = 0;
  int stackIdx = diagonal1;
  int p = W + 1;
  uint xorS = a[p];
  int table[W * W];
  for (int i = 1;i < n;++ i) table[(n - i) * W + i] = i - 1;
  for (int cnt = 1 << diagonal1;stack != cnt;++ stack) {
    while(stackIdx != 0) {
      xorS ^= a[++p]; // 降りるパート
      -- stackIdx;
    }
    map[table[p]].increment(xorS);
    while(stack >> stackIdx & 1) { // 登るパート
      xorS ^= a[p];
      p -= W;
      ++ stackIdx;
    }
    xorS ^= a[p--]; // 子を変更するパート
    xorS ^= a[p += W];
  } // ここで盤面をひっくり返す
  for (int y = 1;y <= n >> 1;++ y) for (int x = 1;x <= n;++ x) swap(a[y * W + x], a[(n + 1 - y) * W + n + 1 - x]);
  if (n & 1) for (int x = 1;x <= n >> 1;++ x) swap(a[((n + 1) >> 1) * W + x], a[((n + 1) >> 1) * W + n + 1 - x]);
  stack = 0;
  stackIdx = diagonal2;
  p = W + 1;
  uint xorT = a[p];
  for (int i = 1;i <= n;++ i) table[i * W + (n - i + 1)] = i - 1;
  for (int cnt = 1 << diagonal2;stack != cnt;++ stack) {
    while(stackIdx != 0) {
      xorT ^= a[++p]; // 降りるパート
      -- stackIdx;
    }
    ans += (p == W + n ? 0 : map[table[p - W + 1]][xorT]) + map[table[p]][xorT];
    while(stack >> stackIdx & 1) { // 登るパート
      xorT ^= a[p];
      p -= W;
      ++ stackIdx;
    }
    xorT ^= a[p--]; // 子を変更するパート
    xorT ^= a[p += W];
  }
  cout << ans << endl;
}

Javaの実装例

import java.util.*;
import static java.lang.System.out;
public class Main {
  public static void main(String[] args) {
    try (Scanner sc = new Scanner(System.in)) {
      int N = sc.nextInt();
      int[][] a = new int[N][N];
      for (int y = 0;y < N;++ y) for (int x = 0;x < N;++ x) a[y][x] = sc.nextInt();
      ArrayList<HashMap<Integer, Long>> S = new ArrayList<>(N), T = new ArrayList<>(N); // Sは始点からN-1回移動したときの値の集合、Tは終点からN-1回移動したときの値の集合
      for (int i = 0;i < N;++ i) {
        S.add(new HashMap<>());
        T.add(new HashMap<>());
      }
      dfs(0, 0, N, a, S, T, a[0][0], a[N - 1][N - 1]);
      long ans = 0;
      for (int i = 0;i < N;++ i) {
        for (Map.Entry<Integer, Long> e : S.get(i).entrySet()) {
          ans += e.getValue() * T.get(i).getOrDefault(e.getKey(), 0L); // 各要素について、対応する要素の個数が答えになる
        }
      }
      out.println(ans);
    }
  }
  
  static void dfs(int x, int y, int N, int[][] a, ArrayList<HashMap<Integer, Long>> S, ArrayList<HashMap<Integer, Long>> T, int xorS, int xorT) {
    if (x + y == N - 1) { // 対角線に辿り着いた
      S.get(x).merge(xorS, 1L, (l, r) -> l + r);
      T.get(y).merge(xorT ^ a[N - y - 1][N - x - 1], 1L, (l, r) -> l + r);
      return;
    }
    dfs(x + 1, y, N, a, S, T, xorS ^ a[y][x + 1], xorT ^ a[N - y - 1][N - x - 2]); // 横に動く
    dfs(x, y + 1, N, a, S, T, xorS ^ a[y + 1][x], xorT ^ a[N - y - 2][N - x - 1]); // 縦に動く
  }
}

posted:
last update: