I - Cake Division 解説 by KKT89

判定問題・条件を満たす切り目の数え上げを線形時間で行う方法

解説の方針と同様に、次の判定問題を解くことを考えます。

\(N\) 個の円環上に並ぶピースを \(K\) 個の区間に分割し、各区間の質量の総和を \(X\) 以上にすることが可能か?

二分探索内で上記の判定問題を \(O(N)\) で解き、得られた質量の最大値に対して、条件を満たす切り目の数え上げを \(O(N)\) で計算する方針を紹介します。

以降、円環は考えにくいので、\(1\) 番目のピースと \(N\) 番目のピースの間で切り開かれたものとして考えることにします。その代わり、必要時に左端と右端に注目し、質量の総和が \(X\) 以上の区間の個数を \(+1\) できるか?の判定を行うことが、本解法のメインアイデアとなります。

二分探索の判定問題

\(L_i :=\) \(i\) を左端として、区間 \([i, j]\) の総和が \(X\) 以上となる、右端 \(j\) の最小値 (ただし、そのような右端が存在しない場合は \(-1\))

とした時、この値は \(i \) の昇順に、尺取り法の要領で計算することで、\(O(N)\) で求めることができます。続いて、次のような値を計算します。

\(dp_i :=\) \(i\) を左端とする時、得られる質量の総和が \(X\) 以上の区間の個数の最大値 \(c\)、及びその場合における右端 \(r\) (区間 \([r, N]\) は未使用という情報を持つ)

これは \(i\) の降順に配列を埋めていくことで、同じく \(O(N)\) で求めることができます。

以上の前計算の結果を用いることで、判定問題全体も \(O(N)\) で解くことができます。具体的には、各 \(i\) について、\(dp_i\) の計算で求めた値 \(c, r\) と、区間 \([1, i)\)\([r, N]\) の総和を見ることで判定可能です。後者の計算は、累積和を予め取っておくことで容易です。

条件を満たす切れ目の数え上げ

\(L_i\)\(dp_i\) は、\(i\) を左端として計算したものですが、\(i\) を右端としたものについても同様に計算することで、この問題にも対応可能です。

\(R_i :=\) \(i\) を右端として、区間 \([j, i]\) の総和が \(X\) 以上となる、左端 \(j\) の最大値 (ただし、そのような左端が存在しない場合は \(N+1\))

\(dp^{\prime}_i :=\) \(i\) を右端とする時、得られる質量の総和が \(X\) 以上の区間の個数の最大値 \(c^\prime\)、及びその場合における左端 \(l\) (区間 \([1, l]\) は未使用という情報を持つ)

をそれぞれ前計算で求めておくことで、各 \(i\) について、\(dp_i, dp^{\prime}_{i-1}\) で求めた値 \(c, c^\prime, l, r\) と、\([1, l]\)\([r, N]\) の総和から、\(i-1\)\(i\) の間に切れ目を入れた場合に有効な切り分け方が存在するかどうか、\(O(N)\) で判定することができます。

実装例 (c++, 40ms)

#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;
typedef unsigned long long int ull;

int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    int N, K;
    cin >> N >> K;
    vector<int> a(N), r_sum(N + 1);
    for (int i = 0; i < N; ++i) {
        cin >> a[i];
        r_sum[i + 1] = r_sum[i] + a[i];
    }

    vector<int> L(N);
    vector<pair<int, int>> dp(N + 1);
    auto get_sum = [&](int l, int r) -> int {
        // calc [l, r)
        return r_sum[r] - r_sum[l];
    };
    auto slv1 = [&](int X) -> bool {
        // 二分探索の判定問題
        {
            // この枝刈りが無くても問題ないが、入れると定数倍改善で早くなる
            int cnt = 0, sum = 0;
            for (int i = 0; i < N; ++i) {
                sum += a[i];
                if (sum >= X) {
                    cnt += 1, sum = 0;
                }
            }
            if (cnt >= K) return true;
            if (cnt <= K - 2) return false;
        }
        int sum = 0;
        for (int i = 0, j = 0; i < N; ++i) {
            while (j < N and sum < X) {
                sum += a[j++];
            }
            L[i] = (sum >= X ? j : -1);
            sum -= a[i];
        }
        dp[N] = {0, N};
        for (int i = N - 1; i >= 0; --i) {
            if (L[i] == -1) {
                dp[i] = {0, i};
            } else {
                dp[i] = dp[L[i]];
                dp[i].first += 1;
            }
        }
        for (int i = 0; i < N; ++i) {
            auto [c, r] = dp[i];
            if (c <= K - 2) break;
            int cnt = c + (get_sum(0, i) + get_sum(r, N) >= X);
            if (cnt >= K) return true;
        }
        return false;
    };
    
    auto slv2 = [&](int X) -> int {
        // 条件を満たす切れ目の数え上げ
        vector<int> R(N);
        vector<pair<int, int>> dp2(N + 1);

        int sum = 0;
        for (int i = 0, j = 0; i < N; ++i) {
            while (j < N and sum < X) {
                sum += a[j++];
            }
            L[i] = (sum >= X ? j : -1);
            sum -= a[i];
        }
        dp[N] = {0, N};
        for (int i = N - 1; i >= 0; --i) {
            if (L[i] == -1) {
                dp[i] = {0, i};
            } else {
                dp[i] = dp[L[i]];
                dp[i].first += 1;
            }
        }

        sum = 0;
        for (int i = N - 1, j = N - 1; i >= 0; --i) {
            while (j >= 0 and sum < X) {
                sum += a[j--];
            }
            R[i] = (sum >= X ? j : N + 1);
            sum -= a[i];
        }
        // 実装の都合上 1-indexed
        dp2[0] = {0, -1};
        for (int i = 0; i < N; ++i) {
            if (R[i] == N + 1) {
                dp2[i + 1] = {0, i};
            } else {
                dp2[i + 1] = dp2[R[i] + 1];
                dp2[i + 1].first += 1;
            }
        }

        int res = 0;
        for (int i = 0; i < N; ++i) {
            auto [c, r] = dp[i];
            auto [c2, l] = dp2[i];
            int cnt = c + c2 + (get_sum(0, l + 1) + get_sum(r, N) >= X);
            if (cnt < K) res += 1;
        }
        return res;
    };

    ll ok = 0, ng = 2e9;
    while (ng - ok > 1) {
        int mid = (ok + ng) / 2;
        if (slv1(mid)) {
            ok = mid;
        } else {
            ng = mid;
        }
    }

    cout << ok << " " << slv2(ok) << endl;
}

投稿日時:
最終更新: