G - Row Column Sums 2 Editorial by TOMWT
Memoized DFS
The solution of A also introduces memoized DFS.
Thanks \(\texttt{s\color{red}{weetorange}}\). I missed the fourth situation during the contest, and he helped me fix that.
Ignore rows with \(R_i=0\) and columns with \(C_i=0\): all of their cells must be \(0\), so they do not affect the count.
Count the number of \(R_i=1\) and call it cnt1. Count the number of \(R_i=2\) and call it cnt2.
Count the number of \(C_i=1\) and call it cnt3. Count the number of \(C_i=2\) and call it cnt4.
No solution
The answer is \(0\) when the total of the row sums differs from the total of the column sums:
\[cnt1+2\times cnt2\neq cnt3+2\times cnt4\]
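The counting step and the feasibility check above can be sketched as follows. This is only an illustration: `countOnesTwos` and `totalsMatch` are hypothetical helper names, and the code assumes every sum lies in \(\{0,1,2\}\).

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: returns {count of 1s, count of 2s}; zeros are skipped.
std::pair<int, int> countOnesTwos(const std::vector<int>& sums) {
    int ones = 0, twos = 0;
    for (int s : sums) {
        assert(0 <= s && s <= 2);  // assumed constraint on every sum
        if (s == 1) ++ones;
        if (s == 2) ++twos;
    }
    return std::make_pair(ones, twos);
}

// True when cnt1 + 2*cnt2 == cnt3 + 2*cnt4, i.e. the row total equals the
// column total. When this is false, no matrix exists and the answer is 0.
bool totalsMatch(const std::vector<int>& R, const std::vector<int>& C) {
    const std::pair<int, int> r = countOnesTwos(R);  // {cnt1, cnt2}
    const std::pair<int, int> c = countOnesTwos(C);  // {cnt3, cnt4}
    return r.first + 2 * r.second == c.first + 2 * c.second;
}
```

For example, `totalsMatch({1, 1}, {2, 2})` is false because the rows sum to 2 while the columns sum to 4.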
Memoized DFS
dfs(i,j,k,l)
is the answer when exactly \(i\) rows have \(R=1\), \(j\) rows have \(R=2\), \(k\) columns have \(C=1\) and \(l\) columns have \(C=2\), that is, \(\sum\limits_t[R_t=1]=i\), \(\sum\limits_t[R_t=2]=j\), \(\sum\limits_t[C_t=1]=k\) and \(\sum\limits_t[C_t=2]=l\).
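When \(l>0\):
Note that \(i+2\times j=k+2\times l\) always holds. Because \(k\) never changes during the recursion, \(j\) is determined by \(i\) and \(l\), so it is enough to memoize on \((i,l)\).
Select one column with \(C=2\) and distribute its sum of 2 among the rows in one of four ways: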
- an \(R=1\) and an \(R=2\): \(\mathrm{dfs}(i,j-1,k,l-1)\times i\times j\). The \(R=1\) disappears and the \(R=2\) becomes an \(R=1\), so \(i\) is unchanged.
- two \(R=1\)s: \(\mathrm{dfs}(i-2,j,k,l-1)\times\binom{i}{2}\). Both \(R=1\)s disappear.
- two \(R=2\)s: \(\mathrm{dfs}(i+2,j-2,k,l-1)\times\binom{j}{2}\). Both \(R=2\)s become \(R=1\).
- one \(R=2\), which receives both units in a single cell: \(\mathrm{dfs}(i,j-1,k,l-1)\times j\). The \(R=2\) disappears.
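A small hand-worked instance may make the four cases concrete. Take two rows with \(R=2\) and two columns with \(C=2\), so the state is dfs(0,2,0,2). The only fact used beyond the four cases is \(\mathrm{dfs}(0,0,0,0)=1\): when nothing is left to place, there is exactly one way to finish.

\[\begin{aligned}
\mathrm{dfs}(0,2,0,2)&=\binom{2}{2}\,\mathrm{dfs}(2,0,0,1)+2\,\mathrm{dfs}(0,1,0,1)\\
\mathrm{dfs}(2,0,0,1)&=\binom{2}{2}\,\mathrm{dfs}(0,0,0,0)=1\\
\mathrm{dfs}(0,1,0,1)&=1\times\mathrm{dfs}(0,0,0,0)=1
\end{aligned}\]

So \(\mathrm{dfs}(0,2,0,2)=1+2=3\), which matches the three matrices \(\begin{pmatrix}1&1\\1&1\end{pmatrix}\), \(\begin{pmatrix}2&0\\0&2\end{pmatrix}\) and \(\begin{pmatrix}0&2\\2&0\end{pmatrix}\).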
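When \(l=0\):
Note that \(i+2\times j=k\) holds (this may help you understand the formula below). Every remaining column has \(C=1\), so each \(R=1\) row picks one unused column and then each \(R=2\) row picks two unused columns:
\[\binom{k}{1}\times\binom{k-1}{1}\times\cdots\times\binom{2\times j+1}{1}\times\binom{2\times j}{2}\times\binom{2\times(j-1)}{2}\times\cdots\times\binom{2}{2}\]
\[=\frac{k!}{2^j}\]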
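To see why the product collapses to this closed form, it may help to split it into its two runs. The short derivation below writes \(\binom{n}{m}\) for the number of ways to choose \(m\) items out of \(n\).

\[\binom{k}{1}\binom{k-1}{1}\cdots\binom{2j+1}{1}=k(k-1)\cdots(2j+1)=\frac{k!}{(2j)!}\]
\[\prod_{t=1}^{j}\binom{2t}{2}=\prod_{t=1}^{j}\frac{2t(2t-1)}{2}=\frac{(2j)!}{2^j}\]

Multiplying the two runs gives \(\frac{k!}{2^j}\). For instance, with \(i=1\) and \(j=1\) (so \(k=3\)) the product is \(\binom{3}{1}\binom{2}{2}=3=\frac{3!}{2}\).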
The answer to the whole problem is dfs(cnt1,cnt2,cnt3,cnt4).
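One way to gain confidence in the recurrence is to compare it with exhaustive enumeration on tiny inputs. The sketch below is not part of the original solution: `countRec`, `countBrute` and `fillCells` are hypothetical names, and the recurrence is evaluated exactly (no modulus, no memoization), so it is only meant for very small states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Exact value of the recurrence for tiny states.
// i/j = rows still needing 1/2, k/l = columns still needing 1/2.
std::uint64_t countRec(int i, int j, int k, int l) {
    assert(i + 2 * j == k + 2 * l);  // invariant of every reachable state
    if (l == 0) {
        std::uint64_t f = 1;  // base case: k! / 2^j
        for (int t = 2; t <= k; ++t) f *= t;
        return f >> j;
    }
    std::uint64_t total = 0;
    if (i >= 1 && j >= 1) total += countRec(i, j - 1, k, l - 1) * i * j;
    if (i >= 2) total += countRec(i - 2, j, k, l - 1) * (i * (i - 1) / 2);
    if (j >= 2) total += countRec(i + 2, j - 2, k, l - 1) * (j * (j - 1) / 2);
    if (j >= 1) total += countRec(i, j - 1, k, l - 1) * j;
    return total;
}

namespace {
// Fills the matrix cell by cell in row-major order and counts complete fillings.
std::uint64_t fillCells(std::vector<int>& rowLeft, std::vector<int>& colLeft,
                        std::size_t r, std::size_t c) {
    if (r == rowLeft.size()) {  // all rows done: every column must be used up
        for (int x : colLeft) if (x != 0) return 0;
        return 1;
    }
    if (c == colLeft.size()) {  // end of row r: its sum must be met exactly
        return rowLeft[r] == 0 ? fillCells(rowLeft, colLeft, r + 1, 0) : 0;
    }
    std::uint64_t total = 0;
    for (int v = 0; v <= rowLeft[r] && v <= colLeft[c]; ++v) {
        rowLeft[r] -= v; colLeft[c] -= v;
        total += fillCells(rowLeft, colLeft, r, c + 1);
        rowLeft[r] += v; colLeft[c] += v;
    }
    return total;
}
}  // namespace

// Brute-force count of non-negative integer matrices with row sums R, column sums C.
std::uint64_t countBrute(std::vector<int> R, std::vector<int> C) {
    return fillCells(R, C, 0, 0);
}
```

A call such as `countBrute({1, 1, 2}, {2, 1, 1})` can then be compared with `countRec(2, 1, 2, 1)`; both count the same set of matrices.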
code
#include<stdio.h>
#include<string.h>
#define mod 998244353
// ans[i][l] memoizes dfs: k is fixed and i+2*j=k+2*l, so (i,l) determines j.
int n,a,cnt1,cnt2,cnt3,cnt4,ans[5001][5001],fac[5001];
inline long long ksm(long long a,int b)// fast power: a^b mod 998244353
{
    long long ans=1;
    for(;b;b>>=1,a=a*a%mod)if(b&1)ans=ans*a%mod;
    return ans;
}
inline long long dfs(int i,int j,int k,int l)
{
    if(!l)return fac[k]*ksm(ksm(2,j),mod-2)%mod;// base case: k!/2^j
    if(~ans[i][l])return ans[i][l];// already computed (-1 means unset)
    long long res=0;
    if(i&&j)res=(res+dfs(i,j-1,k,l-1)*i%mod*j)%mod;// one R=1 and one R=2
    if(i>1)res=(res+dfs(i-2,j,k,l-1)*(i*(i-1ll)>>1))%mod;// two R=1s
    if(j>1)res=(res+dfs(i+2,j-2,k,l-1)*(j*(j-1ll)>>1))%mod;// two R=2s
    if(j)res=(res+dfs(i,j-1,k,l-1)*j)%mod;// both units to one R=2
    return ans[i][l]=res;
}
int main()
{
    fac[0]=1;
    for(int i=1;i<5001;++i)fac[i]=(long long)fac[i-1]*i%mod;// factorials mod p
    memset(ans,-1,sizeof(ans));
    scanf("%d",&n);
    for(int i=0;i<n;++i){scanf("%d",&a);if(a==1)++cnt1;if(a==2)++cnt2;}// row sums
    for(int i=0;i<n;++i){scanf("%d",&a);if(a==1)++cnt3;if(a==2)++cnt4;}// column sums
    if(cnt1+2*cnt2!=cnt3+2*cnt4){putchar('0');return 0;}// no solution
    printf("%lld",dfs(cnt1,cnt2,cnt3,cnt4));
    return 0;
}