M - グループ分け / Grouping Editorial by potato167


この問題は \(O(N2^{N})\) で解くことができます。

公式解説と同じように解くことを考えると、以下の提出のようになります。

実装例 c++ 1065 ms

この実装では、vector base(1 << N) が、集合 \(i\) のグループ化が可能なら base[i] = 1 そうでないならbase[i] = 0 になるようにしています。

そして、答えを求めるために何乗もするパートは以下のようになっています。

#include <atcoder/modint>
using mint = atcoder::modint998244353;
void solve(){
    // 何乗もする
    int ans = 0;
    vector<mint> dp(1 << N);
    dp[0] = 1;
    while (true){
        ans++;
        dp = po167::or_convolution(dp, base);
        if (dp.back().val()){
            cout << ans << "\n";
            break;
        }
    }
}

or_convolution とは、添字を bitwise or する畳み込みで、これの答えの配列の末尾が \(998244353\) で割ったあまりが \(0\) でなければグループが存在するということにしています。

この while 文の中身の計算量は \(O(N2^{N})\) で、これを最大で \(N\) 回するところがボトルネックとなって計算量が \(O(N^{2}2^{N})\) となっています。この部分の計算量を \(O(2^{N})\) にしたいです。

or_convolution を分解すると、以下のようになります。

void solve(){
    // 何乗もする
    int ans = 0;
    vector<mint> dp(1 << N);
    dp[0] = 1;
    while (true){
        ans++;
        po167::or_fwt(base);
        po167::or_fwt(dp);
        for (int i = 0; i < (1 << N); i++){
            dp[i] *= base[i];
        }
        po167::or_ifwt(dp);
        po167::or_ifwt(base);
        if (dp.back().val()){
            cout << ans << "\n";
            break;
        }
    }
}

po167::or_fwt(base), or_ifwt(base) を何回もしているのは無駄で、 while 文の前で \(1\)po167::or_fwt(base) をすれば十分です。また、dp.back() を求めるだけなら、毎回全体を復元しなくて良いです。

具体的には、以下のようにすれば良いです。

void solve(){
    // 何乗もする
    int ans = 0;
    vector<mint> dp(1 << N);
    dp[0] = 1;
    po167::or_fwt(base);
    po167::or_fwt(dp);
    while (true){
        ans++;
        mint dpb = 0;
        for (int i = 0; i < (1 << N); i++){
            dp[i] *= base[i];
            dpb += dp[i] * ((N - std::popcount((unsigned int) i)) & 1 ? 1 : -1);
        }
        if (dpb.val()){
            cout << ans << "\n";
            break;
        }
    }
}

これで、 while 文の中身が \(O(2^{N})\) となったため、全体の計算量は \(O(N2^{N})\) となりました。

実装例 c++ 103 ms

posted:
last update: