Official

F - Variety Split Hard Editorial by MtSaka


\(X_i=(A_1,A_2,\ldots,A_i)\)\(2\) つの空でない連続する部分列に分割したときの種類数の和の最大値

\(L_i=(A_1,A_2,\ldots,A_i)\) の種類数

\(R_i=(A_i,A_{i+1},\ldots,A_N)\) の種類数

とします。 このとき、答えは\(\max_{2 \leq i \leq N-1}(X_i+R_{i+1})\) となります。

\(L,R\) は前と後ろから順に見ていって、種類数を更新していくことで \(O(N)\) で計算できます。(C問題の解説と同様の定義です。)

\(X\) の計算方法について考えます。

\(dp_{i,j}=(A_1,A_2,\ldots,A_j)\)\((A_{j+1},A_{j+2},\ldots,A_i)\) の種類数の和とします。

このとき、\(X_i=\max(dp_{i})\) です。

また、 \(1 \leq j \leq i-1\) については

\(dp_{i+1,j}=\left \{ \begin{array}{ll} dp_{i,j}+1 & ((A_{j+1},A_{j+2},\ldots,A_{i}) に A_{i+1}が含まれない) \\ dp_{i,j} & ((A_{j+1},A_{j+2},\ldots,A_{i}) に A_{i+1}が含まれる) \end{array} \right\} \)

および、 \(dp_{i+1,i}=L_i+1\) が成り立ちます

\((A_{j+1},A_{j+2},\ldots,A_{i})\)\(A_{i+1}\) が含まれるかは \(j\) について単調であるため、ある整数 \(k\) 以上の\(j\) については \(dp_{i+1,j}=dp_{i,j}+1\) が成り立ち、それ以外では\(dp_{i+1,j}=dp_{i,j}\) になります。つまり、\(dp_{i}\) から \(dp_{i+1}\) への遷移は区間加算で表せます。

これはinline DPの形で、区間加算区間最大値の遅延セグメント木を使うことでメモリを使いまわして\(dp\) を計算できます。

よって、時間計算量 \(O(N\log N)\) で解くことができます。

実装例(C++)

#include <bits/stdc++.h>
#include <atcoder/lazysegtree>
using namespace std;
int op(int a, int b) { return max(a, b); }
int e() { return -1e9; }
int mapping(int f, int x) { return f + x; }
int composition(int f, int g) { return f + g; }
int id() { return 0; }
int main() {
    int n;
    cin >> n;
    vector<int> a(n);
    for (auto& e : a)
        cin >> e;
    vector<int> x(n, 0);
    vector<int> suml(n), sumr(n);
    int now = 0;
    vector<int> vis(n + 1, 0);
    for (int i = 0; i < n; ++i) {
        if (++vis[a[i]] == 1) {
            now++;
        }
        suml[i] = now;
    }
    vis = vector<int>(n + 1, 0);
    now = 0;
    for (int i = n - 1; i >= 0; --i) {
        if (++vis[a[i]] == 1) {
            now++;
        }
        sumr[i] = now;
    }
    atcoder::lazy_segtree<int, op, e, int, mapping, composition, id> dp(n);
    vector<int> last(n + 1, -1);
    for (int i = 0; i < n; ++i) {
        dp.apply(last[a[i]] == -1 ? 0 : last[a[i]], i, 1);
        x[i] = dp.prod(0, i);
        dp.set(i, suml[i]);
        last[a[i]] = i;
    }
    int ans = 0;
    for (int i = 1; i < n - 1; ++i) {
        ans = max(ans, x[i] + sumr[i + 1]);
    }
    cout << ans << endl;
}

posted:
last update: