import sys
import numpy as np
import numba
from numba import njit, b1, i4, i8, f8

# Prime modulus; main() keeps every DP entry reduced into [0, MOD).
MOD = 998244353

@njit((i8, i8), cache=True)
def main(N, K):
    """Count size-N multisets of {1, 1/2, 1/4, ...} whose sum is K, modulo MOD.

    dp[n, s] holds the count for n elements with sum s. A multiset either
    contains a 1 (remove it -> dp[n - 1, s - 1]) or contains no 1, in which
    case doubling every element gives a size-n multiset with sum 2 * s
    (-> dp[n, 2 * s]).

    Returns 0 when N < 0 or K lies outside [0, N]; no such multiset exists.
    The guard also matters for safety: numba does not bounds-check array
    indexing by default. Without it, dp[N, K] with K > 2 * N would read out
    of bounds, and a negative K would silently wrap around.
    """
    if N < 0 or K < 0 or K > N:
        # n elements, each at most 1, cannot sum to more than n.
        return 0
    # Columns go up to 2 * N because dp[n, s] reads dp[n, 2 * s] for s <= n.
    # Entries with s > n stay 0.
    dp = np.zeros((N + 1, 2 * N + 1), np.int64)  # dp[element count, sum]
    dp[0, 0] = 1
    for n in range(1, N + 1):
        # Descending s so that dp[n, 2 * s] is final before dp[n, s] reads it.
        for s in range(n, 0, -1):
            x = dp[n - 1, s - 1]  # the multiset contains a 1
            x += dp[n, 2 * s]  # no 1 used: doubled multiset has sum 2 * s
            # Both terms are already < MOD, so one subtraction suffices.
            if x >= MOD:
                x -= MOD
            dp[n, s] = x
    return dp[N, K]

if __name__ == "__main__":
    # Read N and K (whitespace-separated integers) from stdin. The original
    # called an undefined `read`, which raised NameError before any output.
    a, b = map(int, sys.stdin.buffer.read().split())
    print(main(a, b))

# Submission record: 2020-10-31 21:11:24+0900 | D - Number of Multisets | maspy | Python (3.8.2) | AC

