Official

J - Delete Balls Editorial by se1ka2


まず白いボールが存在しないとき、列のボールをすべて消すことができる必要十分条件について考えてみましょう。

列に含まれている赤いボールの数を \(R\)、青いボールの数を \(B\) とします。

このとき、\(R = r \times k\), \(B = b \times k\) を満たす整数 \(k\) が存在しなければ、明らかにすべてのボールを消すことはできません。

逆に、上を満たす \(k\) が存在すればすべてのボールを消すことができることを数学的帰納法で示しましょう。

\(k = 0\) ならすべてのボールが消えています。

\(k \geq 1\) とします。 列を左から \(r + b\) 個ずつ \(k\) 個に分割します。以降これらを区間と呼びます。

ある区間に赤いボールがちょうど \(r\) 個含まれていればその部分を消せるので、帰納法の仮定よりすべてのボールを消すことができます。

そのような区間が存在しなければ、鳩ノ巣原理より、赤いボールが \(r\) 個より多く含まれている区間と \(r\) 個未満含まれている区間がそれぞれ一つ以上存在します。 よってある隣り合う区間が存在し、一方には赤いボールが \(r\) 個より多く、もう一方には赤いボールが \(r\) 個未満含まれています。 この区間の一方から開始し、ボール一つ分ずつ区間をスライドさせていくと、各時点で区間に含まれている赤いボールの数は高々一つずつ変わっていきます。 よってある時点で赤いボールがちょうど \(r\) 個になるので、その部分を消すと帰納法の仮定よりすべてのボールを消すことができます。

これですべてのボールを消すことができる必要十分条件が分かりました。

さて、元の問題の解説に戻りましょう。

列の最初から \(i\) 番目までの部分列を \(s_i\)\(s_i\) に対する最適値を \(d_i\) とします。

\(i\) に対し、\(s_i\) の最適値を達成する操作において、 i 番目のボールは消えているか残っています。

残っている場合、\(d_i = d_{i - 1}\) です。

消えている場合は、白いボールを適切に塗り分けることで、 \(j + 1\) 番目から \(i\) 番目までの部分列に含まれる赤いボールの数が \(r \times k\) 個、青いボールの数が \(b \times k\) 個となるような \(j\), \(k\) が存在し、\(d_i = d_j + k\) となります。

以上より \(d_i = max \{d_{i - 1}, max \{d_j + k | j = i - (r + b) \times k、j と k は上の条件を満たす\}\}\) となります。

\(i\), \(k\), \(j = i - (r + b) \times k\) に対し、上の条件が成り立つかどうかは適当な前処理をすると定数時間で判定できるので、全体の計算量は \(O(N ^ 2 / (r + b))\) となります。

分割統治平面走査を行うことで、計算量を \(O(N(logN)^2)\) に改善することができます。

#include <atcoder/segtree>
#include <algorithm>
#include <iostream>
#include <map>

using namespace std;
using namespace atcoder;

typedef long long ll;
typedef pair<ll, int> P;

const int INF = 10000000;

int op(int a, int b){
    return max(a, b);
}

int e(){
    return -INF;
}

ll l[200005], u[200005];
int dp[200005];
P p[200005];

void divide_conquer(int a, int b, int d){
    if(b - a <= d){
        for(int i = max(a, 1); i < b; i++) dp[i] = max(dp[i], dp[i - 1]);
        return;
    }
    divide_conquer(a, (a + b) / 2, d);
    for(int c = 0; c < d; c++){
        if(a + c + d >= b) continue;
        map<ll, int> mp;
        int m = 0, k = 0;
        for(int i = a + c; i < b; i += d) mp[l[i]] = m++;
        for(auto itr = mp.begin(); itr != mp.end(); itr++) mp[itr->first] = k++;
        for(int e = 0; e < m; e++){
            int i = a + c + d * e;
            p[e] = P(u[i], i);
        }
        sort(p, p + m);
        segtree<int, op, e> seg(k);
        for(int e = 0; e < m; e++){
            int i = p[e].second;
            if(i < (a + b) / 2) seg.set(mp[l[i]], max(seg.get(mp[l[i]]), dp[i] - i / d));
            else dp[i] = max(dp[i], seg.prod(mp[l[i]], k) + i / d);
        }
    }
    divide_conquer((a + b) / 2, b, d);
}

void solve(){
    int n, r, b;
    string s;
    cin >> n >> r >> b >> s;
    int d = r + b;
    for(int i = 1; i <= n; i++){
        if(s[i - 1] == 'R') l[i] = l[i - 1] + b;
        else l[i] = l[i - 1] - r;
        if(s[i - 1] == 'B') u[i] = u[i - 1] - r;
        else u[i] = u[i - 1] + b;
    }
    divide_conquer(0, n + 1, d);
    cout << dp[n] << endl;
}

int main()
{
    solve();
}

posted:
last update: