Official

D - ABA Editorial by cn449


\(N \coloneqq |S|\)、出現しうる文字種数を \(\sigma\) とおきます。本問題では \(\sigma = 26\) です。

\(S_i, S_j, S_k\) をこの順に結合して得られる文字列が回文であることと \(S_i = S_k\) であることは同値です。

言い換え後の数え上げを行う方法は複数考えられますが、ここでは \(2\) 通りの方針を紹介します。

(1) \(j\) に注目する

\(j\) を固定して数え上げることを考えます。このとき、条件を満たす \((i, k)\) の個数は \(S_i = S_k\) であって \(1 \leq i < j\) かつ \(j < k \leq N\) を満たす組の個数です。

\(c \coloneqq S_i\) を固定すると(index ではなく文字を固定しているため高々 \(\sigma\) 通りを調べればよいことに注意してください)、 求めるべきは \(1 \leq i < j\) であって \(S_i = c\) を満たすものの個数および \(j < k \leq N\) であって \(S_k = c\) を満たすものの個数です。

これは累積和の要領で事前に各 \(i, c\) に対し \(1, 2, \ldots , i\) 文字目の \(c\) の個数を計算しておくことにより定数時間で取得でき、全体として \(O(\sigma N)\) 時間で計算できます。

また、求めるべき「\(1 \leq i < j\) であって \(S_i = c\) を満たすものの個数および \(j < k \leq N\) であって \(S_k = c\) を満たすものの個数」が \(j\)\(j + 1\) でほとんど変化しないことを利用すると、適切に差分更新を行いながら計算することで \(O(\sigma N)\) 時間から \(O(N)\) 時間への改善が可能です。

(2) \(i, k\) に注目する

\(S_i = S_k\) のとき、\(j\)\(i < j < k\) を満たすものであればすべて条件を満たすので、\(j\)\(k - i - 1\) 通りの選び方があります。したがって、答えは \(S_i = S_k\) を満たす \((i, k)\) の組に対する \(k - i - 1\) の和です。

したがって、\(S_i = c\) を満たす \(i\) を単調増加に並べた列を \(X_c\) とすると、答えは \(c =\) A, B, \(\ldots\), Z に対して \(\displaystyle\sum_{s = 1}^{|X_c|}\displaystyle\sum_{t = s + 1}^{|X_c|} (X_{c, t} - X_{c, s} - 1)\) を足し合わせたものです。

これは各 \(X_{c, s}\) ごとの寄与を考える、あるいは累積和 \(\displaystyle\sum_{s = 1}^{s'}X_{c,s}\) を計算しながら求めるなどの方法で \(O(|X_c|)\) 時間で値を求めることができます。したがってこの計算を \(c =\) A, B, \(\ldots\), Z に対して行うことにより全体として \(O(N)\) 時間で答えを求めることができます。

実装例(1)

#include <bits/stdc++.h>
using namespace std;

#define rep(i, n) for (ll i = 0; i < (n); i++)
using ll = long long;

int main() {
	string s;
	cin >> s;
	int n = s.size();
	vector<vector<int>> sum(26, vector<int>(n + 1));
	rep(i, n) {
		rep(j, 26) {
			sum[j][i + 1] = sum[j][i];
		}
		sum[s[i] - 'A'][i + 1]++;
	}
	ll ans = 0;
	for (int i = 1; i < n - 1; i++) {
		rep(j, 26) {
			ll l = sum[j][i];
			ll r = sum[j][n] - sum[j][i + 1];
			ans += l * r;
		}
	}
	cout << ans << '\n';
}

実装例(2)

#include <bits/stdc++.h>
using namespace std;

#define rep(i, n) for (ll i = 0; i < (n); i++)
using ll = long long;

int main() {
	string s;
	cin >> s;
	vector<ll> cnt(26), sum(26);
	ll ans = 0;
	int n = s.size();
	rep(i, n) {
		int v = s[i] - 'A';
		ans += (i - 1) * cnt[v] - sum[v];
		cnt[v]++;
		sum[v] += i;
	}
	cout << ans << '\n';
}

posted:
last update: