# Prime modulus (119 * 2**23 + 1); main() reports its answer modulo this value.
MOD = 998244353

def geom_seq_sum(r, n, MOD):
    r"""Return (\sum_{i=0}^{n-1} r^i) mod MOD.

    ``r`` may be any integer. It does not need to be reduced modulo ``MOD``,
    and ``r - 1`` does not need to be invertible modulo ``MOD``. The closed
    form (r^n - 1) / (r - 1) is evaluated exactly: r^n is computed modulo
    (r - 1) * MOD, so the division by r - 1 is an exact integer division.

    Args:
        r: Common ratio (any integer; r <= 0 makes r - 1 negative, which
            Python's pow/floor-division still handle exactly).
        n: Number of terms. n <= 0 yields the empty sum 0.
        MOD: Positive modulus of the result.

    Returns:
        The sum reduced into range(MOD).
    """
    if n <= 0:
        # Empty sum; also keeps pow() away from negative exponents.
        return 0
    if r == 1:
        return n % MOD
    r1 = r - 1
    # r == 1 (mod r1), so r^n - 1 is divisible by r1. Reducing r^n modulo
    # r1 * MOD preserves that divisibility, and the quotient is then
    # correct modulo MOD.
    return (pow(r, n, r1 * MOD) - 1) // r1 % MOD

def geom_seq_sum_1(r, n, MOD):
    r"""Return (\sum_{i=0}^{n-1} i * r^i) mod MOD.

    Uses the identity (r - 1) * S = (n - 1) * r^n - G + 1, where
    G = \sum_{i=0}^{n-1} r^i. r^n is computed modulo (r - 1)^2 * MOD, so
    both divisions by r - 1 are exact integer divisions. Like
    ``geom_seq_sum``, this works for any integer ``r`` (including r that is
    congruent to 1 modulo MOD).

    Args:
        r: Common ratio (any integer).
        n: Number of terms. n <= 0 yields the empty sum 0.
        MOD: Positive modulus of the result.

    Returns:
        The sum reduced into range(MOD).
    """
    if n <= 0:
        # Empty sum; also keeps pow() away from negative exponents.
        return 0
    if r == 1:
        return n * (n - 1) // 2 % MOD
    r1 = r - 1
    x = pow(r, n, r1 * r1 * MOD)
    # x == 1 (mod r1), so this division is exact. g agrees with G
    # modulo r1 * MOD.
    g = (x - 1) // r1
    # The numerator equals (r - 1) * S plus multiples of r1 * MOD, so the
    # second division is exact and its result is congruent to S mod MOD.
    return ((n - 1) * x - g + 1) // r1 % MOD

def main(N, M):
    """Compute the answer for input ``N M``, reduced modulo ``MOD``.

    The answer is the sum of three contributions:

    * the term the original notes label "size == N" (M^N);
    * the term labelled "size == N - 1" (2 * sum of x^(N-1) for x < M);
    * a sum over H = 1..M, evaluated with the closed-form geometric sums
      ``geom_seq_sum`` / ``geom_seq_sum_1`` using the ratios
      (M - H + 1) / M and (M - H) / M.

    NOTE(review): the combinatorial meaning of "size" and of H is not
    documented in this file; confirm against the problem statement.

    When N > 1, M must be invertible modulo MOD; otherwise
    ``pow(M, -1, MOD)`` raises ValueError.
    """
    # Term labelled "size == N".
    ans = pow(M, N, MOD)
    if N == 1:
        return ans

    # Term labelled "size == N - 1".
    ans += 2 * sum(pow(x, N - 1, MOD) for x in range(1, M))

    M_inv = pow(M, -1, MOD)
    # Loop invariants, hoisted out of the H loop.
    c_lin = pow(M, N - 1, MOD)
    # For N == 2 the exponent is -1 (the modular inverse of M). In that case
    # every geometric sum below has 0 terms, so this factor is multiplied by 0.
    c_sq = pow(M, N - 3, MOD)
    k = N - 2  # number of terms in every geometric sum below

    for H in range(1, M + 1):
        hi = M - H + 1
        lo = M - H
        # Reduce the ratios mod MOD. The helpers' results depend only on
        # r mod MOD, and smaller r keeps their internal pow() moduli small.
        r1 = hi * M_inv % MOD
        r2 = lo * M_inv % MOD
        # Each geometric sum is computed once and shared by both terms.
        g1 = geom_seq_sum(r1, k, MOD)
        g2 = geom_seq_sum(r2, k, MOD)
        h1 = geom_seq_sum_1(r1, k, MOD)
        h2 = geom_seq_sum_1(r2, k, MOD)

        # Contribution weighted by 2 * (H - 1).
        x = r1 * g1 - r2 * g2
        ans += 2 * (H - 1) * c_lin * x

        # Contribution weighted by (H - 1)^2.
        s = k * (hi * g1 - lo * g2) - (hi * h1 - lo * h2)
        ans += (H - 1) * (H - 1) * c_sq * s

        # Keep the accumulator bounded instead of growing a huge integer.
        ans %= MOD
    return ans % MOD

if __name__ == "__main__":
    # Read "N M" from stdin and print the answer. The guard keeps the module
    # importable (e.g. for testing the helpers) without blocking on stdin.
    a, b = map(int, input().split())
    print(main(a, b))

