公式

G - 221 Substring 解説 by sheyasutaka


シンプルな形への帰着

\(S\) の連長圧縮を \((v_1, m_1), \cdots, (v_n, m_n)\) とおきます(ここで \((v_i, m_i)\)\(v_i\)\(m_i\) 個並ぶ列を表す).

221 数列であるような連続部分列が \((v_l, v_l), \cdots, (v_r, v_r)\) の形の連長圧縮で抜き出せるとすると,これは以下を満たします.

  • \(v_l \leq m_l\), \(v_r \leq m_r\)
  • \(v_i = m_i\) (\(l < i < r\))

\(S\) の連長圧縮の各組を以下の列で置き換え,つなげてできる列を \(T\) とおきます.

  • \(v_i > m_i\) のとき \((0)\)
  • \(v_i = m_i\) のとき \((v_i)\)
  • \(v_i < m_i\) のとき \((v_i, 0, v_i)\)

このとき,以下の \(2\) つは添字列 \((l, \cdots, r)\) として一対一対応します.

  • \(T\) の連続部分列であって \(0\) を含まないもの \((v_l, \cdots, v_r)\)
  • \(S\) の連続部分列であって 221 列であるもの \((v_l, v_l), \cdots, (v_r, v_r)\)

よって,\(T\) の連続部分列であって \(0\) を含まないものの種類数を数える問題に帰着されます.

帰着後の問題の考察

\(T\)\(l\) 文字目から \(r\) 文字目までの連続部分列を \(T[l\ :\ r]\) と表記します.\(T\) の Suffix Array を \(T[p_1\ :\ |T|] < \cdots < T[p_{|T|}\ :\ |T|]\) とおき,\(T[p_{i-1}\ :\ |T|]\)\(T[p_{i}\ :\ |T|]\) の共通接頭辞の長さを \(\mathrm{lcp}[i]\) とおきます.

\(T[p_i\ :\ |T|]\) (\(i < k\)) の接頭辞としては現れず,\(T[p_k\ :\ |T|]\) の接頭辞としては現れる,\(0\) を含まない連続部分列の個数を考えます.これは,\(p_k\) 文字目を先頭として \(0\) でない文字が続く文字数を \(\mathrm{zerofree}[k]\) とおいたとき,以下で得られます.

  • \(\mathrm{zerofree}[k] \leq \mathrm{lcp}[k]\) のとき,\(0\) 種類.
  • \(\mathrm{zerofree}[k] > \mathrm{lcp}[k]\) のとき,\(\mathrm{zerofree}[k] - \mathrm{lcp}[k]\) 種類.

上述の値の \(k = 1, \cdots, |T|\) に対する総和が求める答えです.

実装方針

有名事実として,文字種類数 \(\sigma\) の文字列に対する Suffix Array は \(O(N + \sigma)\) 時間で,Longest Common Prefix は \(O(N)\) 時間で求まります.

また,\(p_k\) 文字目を先頭として \(0\) でない文字が続く文字数 \(\mathrm{zerofree}[k]\) は,後ろから求めることで全体の値を \(O(N)\) 時間で求めることができます.

以上を実装することでこの問題を高速に解くことができます.時間計算量は \(O(N + \sigma)\) です.

実装例 (C++)

#include <iostream>
using std::cin;
using std::cout;
#include <vector>
using std::vector;
using std::pair;
using std::make_pair;
using std::min;
using std::max;
#include <set>
using std::set;

#include <atcoder/string>
using atcoder::suffix_array;
using atcoder::lcp_array;

typedef long long int ll;

const int SIGMA = 9;

ll n;
vector<int> a;

struct Item {
	ll len;
	int val;
	ll idx;
};

// run-length list of (length, value)
vector<Item> runlength (const vector<int> &a) {
	vector<Item> lens;
	ll i = 0;
	while (i < (ll)a.size()) {
		ll j = i;
		while (j < (ll)a.size() && a[j] == a[i]) j++;

		lens.push_back({len: j-i, val: a[i], idx: i});

		i = j;
	}

	return lens;
}

// count substrings of s that doesn't contain 0 (ignoring position difference)
ll count_zerofrees (const vector<int> &s) {
	// a substring of s that doesn't contain 0 <-1to1-> original 221substr
	const ll m = s.size();
	const vector<int> sa = suffix_array(s, SIGMA);
	const vector<int> lcp = lcp_array(s, sa);

	// max_zerofree[i] := argmax_len [ s[i, i+len) doesn't contain 0 ]
	vector<ll> max_zerofree(m+1);
	max_zerofree[m] = 0;
	for (ll i = m-1; i >= 0; i--) {
		if (s[i] == 0) {
			max_zerofree[i] = 0;
		} else {
			max_zerofree[i] = max_zerofree[i+1] + 1;
		}
	}

	ll sum = 0;
	for (ll i = 0; i < m; i++) {
		ll l = ((i == 0) ? 0 : lcp[i-1]);
		ll r = max_zerofree[sa[i]];
		
		// l < len <= r
		sum += max(0LL, r - l + 0);
	}

	return sum;
}

void solve () {
	// run-length list of (length, value)
	const vector<Item> lens = runlength(a);

	// maximal [l, r) where [lens[l], lens[r]) satisfies either:
	// 1. lens[k].val == lens[k].len, or
	// 2. lens[k].val <  lens[k].len AND k is in {l, r-1}
	vector<pair<ll, ll> > maximals;
	{
		ll idxlen = 0;
		while (idxlen < (ll)lens.size()) {
			if (lens[idxlen].val > lens[idxlen].len) {
				idxlen++;
				continue;
			}

			ll jlen = idxlen + 1;
			ll jsee = jlen;
			while (jlen < (ll)lens.size()) {
				if (lens[jlen].val > lens[jlen].len) {
					jlen += 0;
					jsee += 0;
					break;
				} else if (lens[jlen].val < lens[jlen].len) {
					jlen += 0;
					jsee += 1;
					break;
				} else {
					jlen += 1;
					jsee += 1;
					continue;
				}
			}

			maximals.push_back({idxlen, jsee});
			idxlen = jlen;
		}
	}

	// [0] ++ {[lens[li].val, lens[ri].val)}.join([0]) ++ [0]
	vector<int> s = {0};
	for (const auto &[li, ri] : maximals) {
		for (int i = li; i < ri; i++) {
			s.push_back(lens[i].val);
		}
		s.push_back(0);
	}

	ll ans = count_zerofrees(s);
	cout << ans << "\n";	
}

int main (void) {
	std::cin.tie(nullptr);
	std::ios_base::sync_with_stdio(false);

	cin >> n;
	
	a.resize(n);
	for (ll i = 0; i < n; i++) {
		cin >> a[i];
		assert(a[i] <= SIGMA);
	}

	solve();

	return 0;
}

投稿日時:
最終更新: