Official

E - Popcount Sum 3 Editorial by mechanicalpenciI


問題文中の \(x\) に対する条件、「\(N\) 以下の正の整数」を「\(N\) 以下の非負整数」と言い換えても答えは変わらないため、その条件下で答えを求めます。

また、\(i\geq j\) をみたす非負整数の組 \((i,j)\) に対して、\(2^i\) 未満の非負整数であって、popcount の値が \(j\) であるようなものの個数を \(f(i,j)\), 総和を \(s(i,j)\) として定めます。これらの値は動的計画法を用いて計算することができます。 具体的には、\(f(i,0)=f(i,i)=1\), \(s(i,0)=0\), \(s(i,i)=2^i-1\) であり、\(0<j<i\) のとき \(f(i,j)=f(i-1,j-1)+f(i-1,j)\) および \(s(i,j)=s(i-1,j)+s(i-1,j-1)+2^{i-1}\times f(i-1,j-1)\) であることを用いれば良いです。
また数学的に計算する事で、\(f(i,j)=\binom{i}{j}\), \(s(i,0)=0\), \(s(i,j)=\binom{i-1}{j-1}\times (2^i-1)\)\(j>0\) のとき)となることが証明できます。

\(\mathrm{popcount}(N)=m\) とし、\(N\) を二進数表記したときに \(2^i\) の位が \(1\) である \(i\) を降順に \(d_1,d_2,\ldots, d_m\) \((d_1>d_2\cdots>d_m)\) とします。すなわち、\(N=\displaystyle\sum_{i=1}^m 2^{d_i}\) です。
このとき、\(0\) 以上 \(N\) 以下の整数を次のように \((m+1)\) 個のグループに分け、それぞれにおいて条件をみたすものについての総和を求めることを考えます。\(1\leq i\leq m\) について \(i\) 番目のグループは \(\displaystyle\sum_{j=1}^{i-1} 2^{d_j}\) 以上 \(\displaystyle\sum_{j=1}^{i} 2^{d_j}\) 未満の整数全体を表し、 \((m+1)\) 番目のグループは \(N\) のみからなるものとします。 \((m+1)\) 番目のグループについては \(m=K\) ならば \(N\), そうでないならば \(0\) です。

\(i\) 番目のグループ \((1\leq i\leq m)\) について、属する整数を二進数表記したときの \(2^{d_i}\) の位以上のうち \(1\) である桁は \(d_1, d_2,\ldots, d_{i-1}\) のみであり、\(2^{d_i-1}\) の位以下の桁はすべての組み合わせをとることから、条件をみたすものの個数は \(f(d_i,K-i+1)\) 個であり、それらの総和は \(S(d_i,K-i+1)+f(d_i,K-i+1)\times \displaystyle\sum_{j=1}^{i-1} 2^{d_j}\) となります。なお、\(K-i+1<0\) または \(d_i<K-i+1\) ならば \(0\) 個であり、総和も \(0\) となります。

事前に\(f(i,j)\), \(S(i,j)\) を計算しておくことで、各グループ \(O(1)\) で計算できるため、各テストケースごとに \(O(\log N)\) で解くことができます。また、\(f(i,j)\), \(S(i,j)\)\(0\leq j\leq i\leq\lfloor \log_2 N\rfloor\) の範囲について計算しておけば十分であるため、前計算にかかる時間計算量は \(O((\log N)^2)\) となります。 よって、計算量は全体で \(O((\log N+T)\log N)\) であり、問題の制約下で十分高速です。ゆえに、この問題を解くことができました。

\(998244353\) で割った余りを出力することがあることに注意してください。

c++ による実装例:

#include <bits/stdc++.h>
#include <atcoder/modint>

using namespace std;
using namespace atcoder;
using mint = modint998244353;

#define N 60

mint c[N+1][N+1]; //c[i][j]=(the number of x  s.t. 0<=x<(2^i), popcount(x)=j) for 0<=j<=i<=60
mint s[N+1][N+1]; //s[i][j]=(the sum    of x  s.t. 0<=x<(2^i), popcount(x)=j) for 0<=j<=i<=60

void preset(void){
	for(int i=0;i<=N;i++)for(int j=0;j<=N;j++){
		c[i][j]=0,s[i][j]=0;
	}
	c[0][0]=1;
	for(int i=0;i<60;i++){
		for(int j=0;j<=i;j++){
			c[i+1][j+1]+=c[i][j];
			s[i+1][j+1]+=s[i][j];
			s[i+1][j+1]+=(c[i][j]*((mint)2).pow(i));
			c[i+1][j]+=c[i][j];
			s[i+1][j]+=s[i][j];
		}
	}
	return;
}

int solve(long long n,int k){
	int a[N];
	for(int i=0;i<N;i++){
		a[i]=n&1;
		n=(n>>1);
	}
	int cur=0;
	mint offset=0;
	mint ans=0;
	for(int i=N-1;i>=0;i--){
		if(a[i]==1){
			if(cur<=k){
				ans+=s[i][k-cur];
				ans+=offset*c[i][k-cur];
			}
			cur++;
			offset+=((mint)2).pow(i);
		}
	}
	if(cur==k)ans+=offset;
	return (ans.val());
}

int main(void){
	preset();
	int t,k;
	long long n;
	cin>>t;
	for(int i=0;i<t;i++){
		cin>>n>>k;
		cout<<solve(n,k)<<endl;
	}
	return 0;
}

posted:
last update: