import sys
import numpy as np
import numba
from numba import njit, b1, i1, i4, i8, f8
# Bind the binary-stdin reader methods to short module-level names.
read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline
readlines = sys.stdin.buffer.readlines
# Modulus applied by the modular-arithmetic routines in this file.
MOD = 998_244_353
def from_read(dtype=np.int64):
    """Read the rest of stdin and parse it as whitespace-separated numbers.

    Args:
        dtype: NumPy dtype of the resulting array (``np.int64`` by default).

    Returns:
        A 1-D NumPy array holding the parsed values.
    """
    text = read().decode()
    values = np.fromstring(text, dtype=dtype, sep=' ')
    return values
def from_readline(dtype=np.int64):
    """Read one line from stdin and parse it as whitespace-separated numbers.

    Args:
        dtype: NumPy dtype of the resulting array (``np.int64`` by default).

    Returns:
        A 1-D NumPy array holding the values found on that line.
    """
    line = readline().decode()
    values = np.fromstring(line, dtype=dtype, sep=' ')
    return values
@njit
def mpow(a, n):
    """Return ``a ** n`` modulo ``MOD`` using binary exponentiation.

    Args:
        a: Base. It is reduced modulo ``MOD`` before any multiplication so
            that the products stay small even for a large base.
        n: Exponent; must be non-negative.

    Returns:
        ``a ** n % MOD`` (``1`` when ``n == 0``).

    Raises:
        ValueError: If ``n`` is negative. Without this check the loop would
            never end, because an arithmetic right shift of a negative
            number settles at -1 and never reaches 0.
    """
    if n < 0:
        raise ValueError("mpow: exponent must be non-negative")
    result = 1
    # Reduce once up front: every later product is then below MOD * MOD.
    base = a % MOD
    while n:
        if n & 1:
            result = result * base % MOD
        base = base * base % MOD
        n >>= 1
    return result
@njit
def power_table(a, N):
    """Build the table of successive powers of ``a`` reduced modulo ``MOD``.

    Args:
        a: Base of the powers.
        N: Largest exponent to tabulate.

    Returns:
        An ``int64`` array ``table`` of length ``N + 1`` with ``table[0] == 1``
        and ``table[k] == table[k - 1] * a % MOD`` for ``1 <= k <= N``.
    """
    table = np.ones(N + 1, np.int64)
    for k in range(1, N + 1):
        # Each entry is the previous one multiplied by `a`, then reduced.
        table[k] = table[k - 1] * a % MOD
    return table
@njit((i8, i8), cache=True)
def main(N, M):
    """Compute the answer for ``(N, M)`` modulo ``MOD`` and print it.

    The result is a double sum over a value ``H`` in ``1..M`` and a window
    length ``size`` in ``1..N``. Each term is the product of:

    * ``(M - H + 1) ** size - (M - H) ** size`` -- read as the number of
      length-``size`` windows with entries in ``H..M`` that contain ``H``
      at least once (NOTE(review): interpretation inferred from the
      formula; confirm against the problem statement);
    * the number of ways to fill the remaining ``N - size`` positions,
      summed over every placement of the window: ``H - 1`` choices for
      each cell directly adjacent to the window and ``M`` choices for
      every other cell.

    Args:
        N: Sequence length (int64).
        M: Largest value (int64).

    Returns:
        None. The result, reduced into ``[0, MOD)``, is written with
        ``print``.
    """
    # POWER[v, k] == v ** k % MOD for 0 <= v <= M and 0 <= k <= N.
    POWER = np.zeros((M + 1, N + 1), np.int64)
    for v in range(M + 1):
        POWER[v] = power_table(v, N)

    ans = 0
    for H in range(1, M + 1):
        # Loop-invariant neighbour factors, reduced up front so that no
        # later product can exceed the int64 range.
        below = (H - 1) % MOD
        below_sq = below * below % MOD
        for size in range(1, N + 1):
            window = (POWER[M - H + 1, size] - POWER[M - H, size]) % MOD
            if size == N:
                # The window is the whole sequence: nothing left to fill.
                outside = 1
            else:
                # Window touching the left edge or the right edge: one
                # adjacent cell, the other N - size - 1 cells are free.
                outside = 2 * below * POWER[M, N - size - 1] % MOD
                if size <= N - 2:
                    # Interior window: two adjacent cells, and
                    # N - 1 - size possible starting positions.
                    interior = below_sq * POWER[M, N - size - 2] % MOD
                    positions = (N - 1 - size) % MOD
                    outside = (outside + interior * positions) % MOD
            # Keep the running total reduced instead of letting it grow.
            ans = (ans + window * outside) % MOD
    print(ans)
# Script entry: stdin must hold exactly two integers, passed on to main().
a, b = [int(token) for token in read().split()]
main(a, b)