D - Masked Popcount Editorial
by
physics0523
この解説では、以下の表現を用います。
\(x\) の \(j\) bit 目 … \(x\) を \(2\) 進法で記述した際の \(2^j\) の位
\(x\) の \(j\) bit 目が立っている … \(x\) を \(2\) 進法で記述した際の \(2^j\) の位が \(1\) である
式を整理すると、次の通りに変形できます。
\(\displaystyle \sum_{k=0}^{N} \rm{popcount}\)\((k \mathbin{\&} M) = \) \(\displaystyle \sum_{k=0}^{N}\)\(((k \mathbin{\&} M)\) にて立っている bit の個数 \()\)
\(\displaystyle = \sum_{k=0}^{N} \sum_{j=0}^{\infty}\) \(((k \mathbin{\&} M)\) にて \(j\) bit 目が立っていれば \(1\) 、立っていなければ \(0\) \()\)
\(\displaystyle = \sum_{j=0}^{\infty} \sum_{k=0}^{N}\) \(((k \mathbin{\&} M)\) にて \(j\) bit 目が立っていれば \(1\) 、立っていなければ \(0\) \()\)
\(((k \mathbin{\&} M)\) にて \(j\) bit 目が立っていれば \(1\) 、立っていなければ \(0\) \()\) を更に整理すると、次の通りに言い換えられます。
- もし \(M\) の \(j\) bit 目が立っていなければ、常に \(0\)
- もし \(M\) の \(j\) bit 目が立っていれば、次の通り
- \(k\) の \(j\) bit 目が立っていれば \(1\)
- \(k\) の \(j\) bit 目が立っていなければ \(0\)
この問題の制約内では、 \(M\) の \(60\) bit 目とそれより上の位は常に \(0\) であることに注意すると、最終的に求める和は以下の手順で求めることができます。
- \(j=0,1,\dots,59\) について、以下を繰り返す。
- \(M\) の \(j\) bit 目が立っていなければ何もしない。
- \(M\) の \(j\) bit 目が立っていれば、 \(0\) 以上 \(N\) 以下の整数のうち \(j\) bit 目が立っているものの個数を答えに加算する。
さて、後は以下の数え上げが出来ればこの問題に正解できます。
\(0\) 以上 \(N\) 以下の整数のうち \(j\) bit 目が立っているものの個数を求めよ。
例えば \(j=2\) として考えてみましょう。
\(2\) bit 目が立っている整数は \(4,5,6,7,12,13,14,15,20,21,22,23,28\dots\) です。
この例を観察すると、次の事実が分かります。
- \(k\) を非負整数とした際、 \(0\) 以上 \(k \times 2^{j+1}\) 未満の整数のうち、 \(j\) bit 目が立っているものは \(k \times 2^j\) 個ある。
これは、全ての非負整数 \(i\) について \(i\) の \(j\) bit 目と \(i+2^{j+1}\) の \(j\) bit 目は一致することと、 \(0\) 以上 \(2^{j+1}\) 未満の整数のうち \(j\) bit 目が立っているものは丁度 \(2^j\) 個あることを利用すると証明可能です。
さらに、次のことも言えます。
- \(k\) を非負整数、 \(l\) を \(2^{j+1}\) 未満の整数としたとき、 \(k \times 2^{j+1}\) 以上 \(k \times 2^{j+1} + l\) 以下の整数のうち \(j\) bit 目が立っているものの数は以下の通りである。
- \(l\) が \(2^j\) 未満のとき、 \(0\)
- \(l\) が \(2^j\) 以上のとき、 \(l - 2^j + 1\)
このふたつを統合することで、\(0\) 以上 \(N\) 以下の整数のうち \(j\) bit 目が立っているものの個数を求めることができます。
以上より、この問題に正解することができました。
真の値が \(64\)bit 整数型に収まらないケースがあるので、答えを \(998244353\) で割った余りを取る際に十分注意してください。
実装例1 (C++):
#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
long long f(long long j,long long n){
long long p2=(1ll<<j); // 2^j
long long k=n/(2*p2);
long long res=k*p2;
long long l=n%(2*p2);
if(l>=p2){
res+=(l-p2+1);
}
return res;
}
int main(){
long long n,m;
cin >> n >> m;
long long res=0;
for(long long i=0;i<60;i++){
if(m&(1ll<<i)){
res+=f(i,n);
res%=mod;
}
}
cout << res << "\n";
return 0;
}
bit演算を駆使した実装は次の通りです。
実装例2 (C++):
#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
long long f(long long j,long long n){
long long res=((n>>(j+1))<<j);
if(n&(1ll<<j)){
res+=((n&((1ll<<j)-1))+1);
}
return res;
}
int main(){
long long n,m;
cin >> n >> m;
long long res=0;
for(long long i=0;i<60;i++){
if(m&(1ll<<i)){
res+=f(i,n);
res%=mod;
}
}
cout << res << "\n";
return 0;
}
posted:
last update: