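```python
from collections import defaultdict

N = int(input())
A = [int(i) for i in input().split()]

# dd[a] = number of people whose value is a
dd = defaultdict(int)
for a in A:
    dd[a] += 1

# Factorials and inverse factorials mod 998244353 (indices 0..1001)
mod = 998244353
table = [1] * 1002    # table[i]  = i! mod p
s = 1
for i in range(1, 1002):
    s = s * i % mod   # reduce as we go instead of building the exact big integer
    table[i] = s
rtable = [1] * 1002   # rtable[i] = (i!)^(-1) mod p, via Fermat's little theorem
s = 1
for i in range(1, 1002):
    s = s * pow(i, mod - 2, mod) % mod
    rtable[i] = s

# dp[i][j]: number of ways to have formed all teams of size >= i,
# with j people still unassigned. Team sizes are processed from N down to 1.
dp = [[0] * (N + 1) for i in range(N + 2)]
dp[N + 1][0] = 1
for i in range(N, 0, -1):
    for j in range(N + 1):
        num = dd[i] + j   # people available for teams of size i
        if num > N:
            continue
        k = 0
        # form k teams of size i out of num people:
        # num! / ((num - k*i)! * (i!)^k * k!) ways
        while k * i <= num:
            C = table[num]
            C *= rtable[num - k * i]
            C %= mod
            C *= pow(rtable[i], k, mod)   # (i!)^(-k), recomputed with pow() every step
            C %= mod
            C *= rtable[k]
            C %= mod
            dp[i][num - k * i] += C * dp[i + 1][j]
            dp[i][num - k * i] %= mod
            k += 1
print(dp[1][0])
```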

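| Submission Time | Task | User | Language | Score | Code Size | Status | Exec Time | Memory |
|---|---|---|---|---|---|---|---|---|
| 2018-08-12 16:56:08+0900 | F - チーム分け (Team Division) | okumura | PyPy3 (2.4.0) | 0 | 884 Byte | TLE (Time Limit Exceeded) | 2109 ms | 80984 KB |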
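The submission above exceeded the time limit. One likely cost is the `pow(rtable[i], k, mod)` call in the innermost loop, together with the full `(N+2) x (N+1)` table. Below is a sketch, not a verified accepted solution, of the same DP with powers of `1/i!` built incrementally, unreachable states skipped, inverse factorials derived from a single modular inverse, and only two DP rows kept. It assumes, as read off the original code, that a person with value `a_j` may join any team of size at most `a_j`, that `len(a) == n`, and that every `a_j >= 1`. The names `count_teamings` and `brute_force` are hypothetical; the brute force enumerates set partitions so the DP can be cross-checked on tiny inputs.

```python
import random
from itertools import combinations

MOD = 998244353


def count_teamings(n, a, mod=MOD):
    """Count ways to split n people into teams where person j's team has size
    at most a[j] (assumed reading of the original DP), modulo `mod`.
    Assumes len(a) == n and every a[j] >= 1."""
    # cnt[s] = number of people whose limit is s; limits above n act like n
    cnt = [0] * (n + 1)
    for x in a:
        cnt[min(x, n)] += 1

    # Factorials, then inverse factorials from one Fermat inverse (mod is prime)
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % mod
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], mod - 2, mod)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod

    # dp[j]: ways to have formed every team larger than the current size,
    # with j people still unassigned; only two rows are kept
    dp = [0] * (n + 1)
    dp[0] = 1
    for size in range(n, 0, -1):
        new = [0] * (n + 1)
        step = inv_fact[size]                # 1 / size!
        for j in range(n + 1):
            ways = dp[j]
            if ways == 0:                    # skip unreachable states
                continue
            avail = cnt[size] + j            # people allowed in a team of this size
            base = ways * fact[avail] % mod
            pw = 1                           # (1 / size!)^k, updated incrementally
            k = 0
            while k * size <= avail:
                rest = avail - k * size
                # avail! / (rest! * (size!)^k * k!) ways to form k teams
                term = base * inv_fact[rest] % mod * pw % mod * inv_fact[k] % mod
                new[rest] = (new[rest] + term) % mod
                pw = pw * step % mod
                k += 1
        dp = new
    return dp[0]


def brute_force(a):
    """Enumerate all set partitions of people 0..len(a)-1; tiny inputs only."""
    def rec(remaining):
        if not remaining:
            return 1
        first, rest = remaining[0], remaining[1:]
        total = 0
        # the first unassigned person joins a team with some subset of the rest
        for r in range(len(rest) + 1):
            for others in combinations(rest, r):
                team = (first,) + others
                if all(len(team) <= a[p] for p in team):
                    left = tuple(p for p in rest if p not in others)
                    total += rec(left)
        return total

    return rec(tuple(range(len(a))))


if __name__ == "__main__":
    # Randomized cross-check of the DP against the brute force on tiny inputs
    rng = random.Random(0)
    for _ in range(200):
        n = rng.randint(1, 6)
        a = [rng.randint(1, n) for _ in range(n)]
        assert count_teamings(n, a) == brute_force(a), a
    print("ok")
```

To use it as a submission, read `N` and `A` the way the original does and print `count_teamings(N, A)`. The inner loop still runs roughly the sum over sizes of `N * (N / size)` steps, so about `N^2 log N` in the worst case; the change is a smaller constant per step, and whether that fits the judge's limit has not been checked here.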

