Official

D - Masked Popcount Editorial by physics0523


この解説では、以下の表現を用います。

\(x\)\(j\) bit 目 … \(x\)\(2\) 進法で記述した際の \(2^j\) の位
\(x\)\(j\) bit 目が立っている … \(x\)\(2\) 進法で記述した際の \(2^j\) の位が \(1\) である


式を整理すると、次の通りに変形できます。

\(\displaystyle \sum_{k=0}^{N} \rm{popcount}\)\((k \mathbin{\&} M) = \) \(\displaystyle \sum_{k=0}^{N}\)\(((k \mathbin{\&} M)\) にて立っている bit の個数 \()\)
\(\displaystyle = \sum_{k=0}^{N} \sum_{j=0}^{\infty}\) \(((k \mathbin{\&} M)\) にて \(j\) bit 目が立っていれば \(1\) 、立っていなければ \(0\) \()\)
\(\displaystyle = \sum_{j=0}^{\infty} \sum_{k=0}^{N}\) \(((k \mathbin{\&} M)\) にて \(j\) bit 目が立っていれば \(1\) 、立っていなければ \(0\) \()\)

\(((k \mathbin{\&} M)\) にて \(j\) bit 目が立っていれば \(1\) 、立っていなければ \(0\) \()\) を更に整理すると、次の通りに言い換えられます。

  • もし \(M\)\(j\) bit 目が立っていなければ、常に \(0\)
  • もし \(M\)\(j\) bit 目が立っていれば、次の通り
    • \(k\)\(j\) bit 目が立っていれば \(1\)
    • \(k\)\(j\) bit 目が立っていなければ \(0\)

この問題の制約内では、 \(M\)\(60\) bit 目とそれより上の位は常に \(0\) であることに注意すると、最終的に求める和は以下の手順で求めることができます。

  • \(j=0,1,\dots,59\) について、以下を繰り返す。
    • \(M\)\(j\) bit 目が立っていなければ何もしない。
    • \(M\)\(j\) bit 目が立っていれば、 \(0\) 以上 \(N\) 以下の整数のうち \(j\) bit 目が立っているものの個数を答えに加算する。

さて、後は以下の数え上げが出来ればこの問題に正解できます。

\(0\) 以上 \(N\) 以下の整数のうち \(j\) bit 目が立っているものの個数を求めよ。

例えば \(j=2\) として考えてみましょう。
\(2\) bit 目が立っている整数は \(4,5,6,7,12,13,14,15,20,21,22,23,28\dots\) です。
この例を観察すると、次の事実が分かります。

  • \(k\) を非負整数とした際、 \(0\) 以上 \(k \times 2^{j+1}\) 未満の整数のうち、 \(j\) bit 目が立っているものは \(k \times 2^j\) 個ある。

これは、全ての非負整数 \(i\) について \(i\)\(j\) bit 目と \(i+2^{j+1}\)\(j\) bit 目は一致することと、 \(0\) 以上 \(2^{j+1}\) 未満の整数のうち \(j\) bit 目が立っているものは丁度 \(2^j\) 個あることを利用すると証明可能です。

さらに、次のことも言えます。

  • \(k\) を非負整数、 \(l\)\(2^{j+1}\) 未満の整数としたとき、 \(k \times 2^{j+1}\) 以上 \(k \times 2^{j+1} + l\) 以下の整数のうち \(j\) bit 目が立っているものの数は以下の通りである。
    • \(l\)\(2^j\) 未満のとき、 \(0\)
    • \(l\)\(2^j\) 以上のとき、 \(l - 2^j + 1\)

このふたつを統合することで、\(0\) 以上 \(N\) 以下の整数のうち \(j\) bit 目が立っているものの個数を求めることができます。

以上より、この問題に正解することができました。
真の値が \(64\)bit 整数型に収まらないケースがあるので、答えを \(998244353\) で割った余りを取る際に十分注意してください。

実装例1 (C++):

#include<bits/stdc++.h>
#define mod 998244353

using namespace std;

long long f(long long j,long long n){
  long long p2=(1ll<<j); // 2^j
  long long k=n/(2*p2);
  long long res=k*p2;
  long long l=n%(2*p2);
  if(l>=p2){
    res+=(l-p2+1);
  }
  return res;
}

int main(){
  long long n,m;
  cin >> n >> m;
  long long res=0;
  for(long long i=0;i<60;i++){
    if(m&(1ll<<i)){
      res+=f(i,n);
      res%=mod;
    }
  }
  cout << res << "\n";
  return 0;
}

bit演算を駆使した実装は次の通りです。

実装例2 (C++):

#include<bits/stdc++.h>
#define mod 998244353

using namespace std;

long long f(long long j,long long n){
  long long res=((n>>(j+1))<<j);
  if(n&(1ll<<j)){
    res+=((n&((1ll<<j)-1))+1);
  }
  return res;
}

int main(){
  long long n,m;
  cin >> n >> m;
  long long res=0;
  for(long long i=0;i<60;i++){
    if(m&(1ll<<i)){
      res+=f(i,n);
      res%=mod;
    }
  }
  cout << res << "\n";
  return 0;
}

posted:
last update: