G - Row Column Sums 2 Editorial by TOMWT

Memoized DFS

The editorial for problem A also introduces memoized DFS.

Thanks to \(\texttt{s\color{red}{weetorange}}\): I missed the fourth case during the contest, and he helped me fix it.



Rows with \(R_i=0\) and columns with \(C_i=0\) can be ignored, since every entry in them must be 0.

Let cnt1 be the number of rows with \(R_i=1\) and cnt2 the number of rows with \(R_i=2\).

Let cnt3 be the number of columns with \(C_i=1\) and cnt4 the number of columns with \(C_i=2\).

No solution

The answer is 0 when the total of the row sums differs from the total of the column sums:

\[cnt1+2\times cnt2\neq cnt3+2\times cnt4\]
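As a minimal sketch of the counting and the no-solution test above (the names `SumCounts`, `countSums` and `hasSolution` are hypothetical, chosen only for illustration), assuming the sums are already read into two vectors:

```cpp
#include <cassert>
#include <vector>

// Hypothetical container for the four counts; rows/columns with sum 0 are ignored.
struct SumCounts {
    int cnt1 = 0;  // rows with R_i = 1
    int cnt2 = 0;  // rows with R_i = 2
    int cnt3 = 0;  // columns with C_i = 1
    int cnt4 = 0;  // columns with C_i = 2
};

// Hypothetical helper: count the rows/columns whose sum is 1 or 2.
SumCounts countSums(const std::vector<int>& R, const std::vector<int>& C) {
    SumCounts s;
    for (int r : R) {
        if (r == 1) ++s.cnt1;
        else if (r == 2) ++s.cnt2;
    }
    for (int c : C) {
        if (c == 1) ++s.cnt3;
        else if (c == 2) ++s.cnt4;
    }
    return s;
}

// The total of all entries must be the same whether summed by rows or by columns.
bool hasSolution(const SumCounts& s) {
    return s.cnt1 + 2 * s.cnt2 == s.cnt3 + 2 * s.cnt4;
}
```

For example, row sums {2, 2} and column sums {1, 1} give totals 4 and 2, so the answer would be 0 without running the DFS.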

Memoized DFS

dfs(i,j,k,l) is the number of valid ways to fill the matrix when there remain \(i\) rows with \(R=1\), \(j\) rows with \(R=2\), \(k\) columns with \(C=1\) and \(l\) columns with \(C=2\).

When \(l>0\):

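Note that \(i+2\times j=k+2\times l\) holds, since both sides equal the total that is still left to place.

Take one column with \(C=2\) and distribute its sum of 2 among the rows (writing \(f\) for dfs, and \(C_n^m\) for the binomial coefficient):

  • one unit to an \(R=1\) row and one to an \(R=2\) row: \(f(i-1+1,j-1,k,l-1)\times i\times j\). The \(R=1\) row is finished and the \(R=2\) row becomes an \(R=1\) row.
  • one unit to each of two \(R=1\) rows: \(f(i-2,j,k,l-1)\times C_i^2\). Both \(R=1\) rows are finished.
  • one unit to each of two \(R=2\) rows: \(f(i+2,j-2,k,l-1)\times C_j^2\). Both \(R=2\) rows become \(R=1\) rows.
  • both units to a single \(R=2\) row, so that cell holds 2: \(f(i,j-1,k,l-1)\times j\). That \(R=2\) row is finished.

dfs(i,j,k,l) is the sum of these four terms.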
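To sanity-check the four cases, here is a small cross-check sketch. It assumes, as the fourth case implies, that matrix entries are non-negative integers (so one cell may hold 2). `countByRecurrence` is an exact, non-modular version of the recurrence memoized on all four arguments with a `std::map`, and `bruteForce` enumerates matrices cell by cell; both names, and `countViaDfs`, are hypothetical and meant only for tiny inputs.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <tuple>
#include <vector>

using u64 = unsigned long long;
using Memo = std::map<std::tuple<int, int, int, int>, u64>;

// Exact version of the recurrence; only for tiny inputs where counts fit in u64.
u64 countByRecurrence(int i, int j, int k, int l, Memo& memo) {
    if (l == 0) {
        // Base case k! / 2^j; 2^j divides k! because k = i + 2j >= 2j.
        u64 f = 1;
        for (int t = 2; t <= k; ++t) f *= t;
        return f >> j;
    }
    auto key = std::make_tuple(i, j, k, l);
    auto it = memo.find(key);
    if (it != memo.end()) return it->second;
    u64 res = 0;
    if (i >= 1 && j >= 1) res += countByRecurrence(i, j - 1, k, l - 1, memo) * i * j;  // R=1 + R=2
    if (i >= 2) res += countByRecurrence(i - 2, j, k, l - 1, memo) * (u64(i) * (i - 1) / 2);  // two R=1
    if (j >= 2) res += countByRecurrence(i + 2, j - 2, k, l - 1, memo) * (u64(j) * (j - 1) / 2);  // two R=2
    if (j >= 1) res += countByRecurrence(i, j - 1, k, l - 1, memo) * j;  // one R=2 row gets 2
    return memo[key] = res;
}

// Count the four kinds of sums, reject mismatched totals, then run the recurrence.
u64 countViaDfs(const std::vector<int>& R, const std::vector<int>& C) {
    int c1 = 0, c2 = 0, c3 = 0, c4 = 0;
    for (int r : R) { if (r == 1) ++c1; if (r == 2) ++c2; }
    for (int c : C) { if (c == 1) ++c3; if (c == 2) ++c4; }
    if (c1 + 2 * c2 != c3 + 2 * c4) return 0;
    Memo memo;
    return countByRecurrence(c1, c2, c3, c4, memo);
}

// Brute force: fill cells in row-major order, never exceeding the remaining
// row or column sum, and count fillings that use up every sum exactly.
u64 bruteForce(std::vector<int> R, std::vector<int> C) {
    int n = static_cast<int>(R.size());
    std::function<u64(int)> go = [&](int cell) -> u64 {
        if (cell == n * n) {
            for (int x : R) if (x) return 0;
            for (int x : C) if (x) return 0;
            return 1;
        }
        int r = cell / n, c = cell % n;
        u64 total = 0;
        for (int v = 0; v <= R[r] && v <= C[c]; ++v) {
            R[r] -= v; C[c] -= v;
            total += go(cell + 1);
            R[r] += v; C[c] += v;
        }
        return total;
    };
    return go(0);
}
```

For row sums {2, 2} and column sums {2, 2}, both functions count 3 matrices: \(\begin{pmatrix}a&2-a\\2-a&a\end{pmatrix}\) for \(a=0,1,2\). The \(a=0\) and \(a=2\) matrices come from the fourth case, which is why dropping it gives a wrong answer.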

When \(l=0\):

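Note that \(i+2\times j=k\) holds (this may help you understand the formula below). Each of the \(k\) remaining columns puts its single 1 into some row: first choose one column for each \(R=1\) row, then a pair of columns for each \(R=2\) row.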

\[C_k^1\times C_{k-1}^1\times \cdots\times C_{2\times j+1}^1\times C_{2\times j}^2\times C_{2\times(j-1)}^2\times\cdots\times C_2^2\]

\[=\frac{k!}{2^j}\]
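A quick numeric check of this identity, as a sketch: `productForm` multiplies the factors of the left-hand side, `closedForm` computes \(k!/2^j\) exactly, and `baseCaseMod` mirrors what the solution code does modulo 998244353, inverting \(2^j\) by Fermat's little theorem (the modulus is prime). All three function names are hypothetical.

```cpp
#include <cassert>

using u64 = unsigned long long;

// Left-hand side: k, k-1, ..., 2j+1 choices for the R=1 rows, then
// C(2j,2), C(2j-2,2), ..., C(2,2) choices for the R=2 rows. Requires k >= 2j.
u64 productForm(int k, int j) {
    u64 p = 1;
    for (int t = k; t > 2 * j; --t) p *= t;                         // C_t^1 = t
    for (int t = 2 * j; t >= 2; t -= 2) p *= u64(t) * (t - 1) / 2;  // C_t^2
    return p;
}

// Right-hand side k! / 2^j, exact; k! fits in u64 for k <= 20.
u64 closedForm(int k, int j) {
    u64 f = 1;
    for (int t = 2; t <= k; ++t) f *= t;
    return f >> j;
}

const long long MOD = 998244353;

// Fast modular exponentiation a^e mod MOD.
long long powMod(long long a, long long e) {
    long long r = 1;
    for (a %= MOD; e; e >>= 1, a = a * a % MOD)
        if (e & 1) r = r * a % MOD;
    return r;
}

// k! * (2^j)^(MOD-2) mod MOD, the same computation as the l = 0 branch of dfs.
long long baseCaseMod(int k, int j) {
    long long f = 1;
    for (int t = 2; t <= k; ++t) f = f * t % MOD;
    return f * powMod(powMod(2, j), MOD - 2) % MOD;
}
```

For instance, with \(k=4\) and \(j=2\) (two \(R=2\) rows sharing four \(C=1\) columns) the count is \(C_4^2\times C_2^2=6=4!/2^2\).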

The answer to the whole problem is dfs(cnt1,cnt2,cnt3,cnt4).
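The code memoizes on ans[i][l] only. This works because all four transitions keep \(i+2\times j=k+2\times l\), and \(k\) never changes while \(l>0\), so \(j\) can be recovered from \((i,k,l)\). Also \(i+j\) never exceeds cnt1+cnt2 \(\le N\), so a 5001 by 5001 table covers every state. A small sketch of that reasoning, with hypothetical helper names `invariantHolds` and `recoverJ`:

```cpp
#include <cassert>

// Hypothetical helper: remaining row total equals remaining column total.
bool invariantHolds(int i, int j, int k, int l) {
    return i + 2 * j == k + 2 * l;
}

// Hypothetical helper: recover j from (i, k, l) via i + 2j = k + 2l.
// Within one run of the recursion k is fixed, so (i, l) identifies the state.
int recoverJ(int i, int k, int l) {
    int twiceJ = k + 2 * l - i;
    assert(twiceJ >= 0 && twiceJ % 2 == 0);
    return twiceJ / 2;
}
```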

code

#include<stdio.h>
#include<string.h>
#define mod 998244353
//ans[i][l] memoizes dfs: k never changes and j follows from i+2j=k+2l, so (i,l) identifies a state
int n,a,cnt1,cnt2,cnt3,cnt4,ans[5001][5001],fac[5001];
inline long long ksm(long long a,int b)//fast modular exponentiation: a^b mod p
{
	long long ans=1;
	for(;b;b>>=1,a*=a,a%=mod)if(b&1)ans*=a,ans%=mod;
	return ans;
}
inline long long dfs(const int&i,const int&j,const int&k,const int&l)
{
	if(!l)return fac[k]*ksm(ksm(2,j),mod-2)%mod;//l=0: k!/2^j, inverse of 2^j via Fermat
	if(~ans[i][l])return ans[i][l];//already computed (ans is initialized to -1)
	ans[i][l]=0;
	if(i&&j)ans[i][l]=(ans[i][l]+dfs(i,j-1,k,l-1)*i%mod*j)%mod;//an R=1 row and an R=2 row
	if(i>1)ans[i][l]=(ans[i][l]+dfs(i-2,j,k,l-1)*(i*(i-1ll)>>1))%mod;//two R=1 rows
	if(j>1)ans[i][l]=(ans[i][l]+dfs(i+2,j-2,k,l-1)*(j*(j-1ll)>>1))%mod;//two R=2 rows
	if(j)ans[i][l]=(ans[i][l]+dfs(i,j-1,k,l-1)*j)%mod;//a single R=2 row gets 2
	return ans[i][l];
}
int main()
{
	fac[0]=1;for(int i=1;i<5001;fac[i]=(long long)(fac[i-1])*i%mod,++i);//factorials mod p
	memset(ans,-1,sizeof(ans));scanf("%d",&n);
	for(int i=n;i--;scanf("%d",&a),a==1&&++cnt1,a==2&&++cnt2);//row sums
	for(int i=n;i--;scanf("%d",&a),a==1&&++cnt3,a==2&&++cnt4);//column sums
	if(cnt1+2*cnt2!=cnt3+2*cnt4){putchar('0');return 0;}//no solution
	printf("%lld",dfs(cnt1,cnt2,cnt3,cnt4));
	return 0;
}
