M - グループ分け / Grouping Editorial
by
potato167
この問題は \(O(N2^{N})\) で解くことができます。
公式解説と同じように解くことを考えると、以下の提出のようになります。
この実装では、vector base(1 << N) が、集合 \(i\) のグループ化が可能なら base[i] = 1 そうでないならbase[i] = 0 になるようにしています。
そして、答えを求めるために何乗もするパートは以下のようになっています。
#include <atcoder/modint>
using mint = atcoder::modint998244353;
void solve(){
// 何乗もする
int ans = 0;
vector<mint> dp(1 << N);
dp[0] = 1;
while (true){
ans++;
dp = po167::or_convolution(dp, base);
if (dp.back().val()){
cout << ans << "\n";
break;
}
}
}
or_convolution とは、添字を bitwise or する畳み込みで、これの答えの配列の末尾が \(998244353\) で割ったあまりが \(0\) でなければグループが存在するということにしています。
この while 文の中身の計算量は \(O(N2^{N})\) で、これを最大で \(N\) 回するところがボトルネックとなって計算量が \(O(N^{2}2^{N})\) となっています。この部分の計算量を \(O(2^{N})\) にしたいです。
or_convolution を分解すると、以下のようになります。
void solve(){
// 何乗もする
int ans = 0;
vector<mint> dp(1 << N);
dp[0] = 1;
while (true){
ans++;
po167::or_fwt(base);
po167::or_fwt(dp);
for (int i = 0; i < (1 << N); i++){
dp[i] *= base[i];
}
po167::or_ifwt(dp);
po167::or_ifwt(base);
if (dp.back().val()){
cout << ans << "\n";
break;
}
}
}
po167::or_fwt(base), or_ifwt(base) を何回もしているのは無駄で、 while 文の前で \(1\) 回 po167::or_fwt(base) をすれば十分です。また、dp.back() を求めるだけなら、毎回全体を復元しなくて良いです。
具体的には、以下のようにすれば良いです。
void solve(){
// 何乗もする
int ans = 0;
vector<mint> dp(1 << N);
dp[0] = 1;
po167::or_fwt(base);
po167::or_fwt(dp);
while (true){
ans++;
mint dpb = 0;
for (int i = 0; i < (1 << N); i++){
dp[i] *= base[i];
dpb += dp[i] * ((N - std::popcount((unsigned int) i)) & 1 ? 1 : -1);
}
if (dpb.val()){
cout << ans << "\n";
break;
}
}
}
これで、 while 文の中身が \(O(2^{N})\) となったため、全体の計算量は \(O(N2^{N})\) となりました。
posted:
last update: