E - Median Replace Editorial by potato167


\(0\)\(-1\) に置き換え、文字列を数列としてみると、問題は以下のようになります。

「奇数長の連続部分列を選んだ後消す。消した部分列の総和が負なら \(-1\) を、正であるなら \(1\) をその場所に挿入する」という操作をして、\(1\) が一個だけ残るようにできる数列は何個ありますか。

数列の総和が正にすれば良いので、総和が \(-3\) である区間を消すことで、数列全体の総和を増やしていきます。また、その回数を増やすためにできるだけ区間は短い方が良いです。よって、前から見ていき、総和が \(-3\) になるものを見つけ次第、それを \(-1\) に置き換えることを考えます。

\(dp[i][j][k]\)\(i\) 文字目まで見て、累積和の最大値が \(j\) で、\(i\) 文字目までの累積和と \(j\) との差が \(k\) であるような文字列の個数とします。

\(i\) 文字目が \(0(-1)\) のとき、累積和が減るので、以下のように更新します。

\[dp[i + 1][j][k + 1] += dp[i][j][k]\]

\(i\) 文字目が \(1\) のとき、累積和が増えます。\(k = 0\) か否かで更新の仕方が増えます。

  • \(k = 0\) のとき、\(dp[i + 1][j + 1][0] += dp[i][j][0]\)
  • \(k = 1\) のとき、\(dp[i + 1][j][k - 1] += dp[i][j][k]\)

また、\(k = 3\) になったとき、総和が \(-3\) であるような区間が存在するということになるので、以下のように更新します。

\[dp[i][j][1] += dp[i][j][3]\]

\[dp[i][j][3] \leftarrow 0\]

そして、\(dp[N][j][k]\) のうち、\(j > k\) であるようなものが条件を満たす数列です。逆に、\(j < k\) であるようなものが条件を満たさない数列です。\(k \leq 3\) であることから、条件を満たさない数列の数え上げをする際、\(j\)\(3\) 以上になるようなものは持たなくていいです。よって、dp の更新が \(O(1)\) になるため、計算量は \(O(N)\) です。

#include <bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for (int i=(int)(a);i<(int)(b);i++)
#include<atcoder/modint>
using mint = atcoder::modint1000000007;


int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    
    string S;
    cin >> S;
    int N = S.size();
    const int D = 4;
    vector dp(D, vector<mint>(D));
    dp[0][0] = 1;
    mint sum = 1;
    rep(i, 0, N){
        vector n_dp(D, vector<mint>(D));
        if (S[i] == '?') sum *= 2;
        if (S[i] == '0' || S[i] == '?'){
            rep(j, 0, D) rep(k, 0, D - 1){
                n_dp[j][k + 1] += dp[j][k];
            }
        }
        if (S[i] == '1' || S[i] == '?'){
            rep(j, 0, D) rep(k, 0, D){
                if (k == 0){
                    if (j != D - 1) n_dp[j + 1][k] += dp[j][k];
                }
                else n_dp[j][k - 1] += dp[j][k];
            }
        }
        swap(n_dp, dp);
        rep(j, 0, D) dp[j][1] += dp[j][3], dp[j][3] = 0;
    }
    rep(j, 0, D) rep(k, 0, D) if (j < k) sum -= dp[j][k];
    cout << sum.val() << "\n";
}

posted:
last update: