Submission #20934340


Source Code Expand

Copy
import sys
import numpy as np
import numba
from numba import njit, b1, i1, i4, i8, f8

# Byte-level stdin readers, bound once for speed.
read, readline, readlines = (
    sys.stdin.buffer.read,
    sys.stdin.buffer.readline,
    sys.stdin.buffer.readlines,
)

MOD = 998244353  # modulus for all arithmetic; read as a global by the njit functions

def from_read(dtype=np.int64):
    """Parse all remaining stdin as whitespace-separated numbers of *dtype*."""
    text = read().decode()
    return np.fromstring(text, dtype=dtype, sep=' ')


def from_readline(dtype=np.int64):
    """Parse one stdin line as whitespace-separated numbers of *dtype*."""
    line = readline().decode()
    return np.fromstring(line, dtype=dtype, sep=' ')

@njit
def mpow(a, n):
    """Return a**n reduced modulo MOD via square-and-multiply.

    Intended for n >= 0; each step multiplies two values that are already
    reduced (after the first squaring), so products stay within int64 when
    a itself is below MOD.
    """
    result = 1
    base = a
    while n != 0:
        if n & 1:
            result = result * base % MOD
        base = base * base % MOD
        n = n >> 1
    return result

@njit
def power_table(a, N):
    """Return an int64 array t of length N + 1 with t[0] = 1 and
    t[k] = t[k - 1] * a % MOD, i.e. the successive powers of a mod MOD."""
    table = np.ones(N + 1, np.int64)
    for k in range(1, N + 1):
        table[k] = table[k - 1] * a % MOD
    return table

@njit((i8, i8), cache=True)
def main(N, M):
    """Print the total for length-N sequences over M values, modulo MOD.

    The sum runs over heights H = 1..M and run lengths size = 1..N. Each
    term counts the sequences in which some contiguous run of length
    `size` has every value >= H and at least one value equal to H, with
    both array-internal neighbours of the run taking one of H - 1 values
    (below H). All remaining positions are free (M choices each).
    NOTE(review): this presumably equals the task's per-sequence score
    summed over all M**N sequences, with entries ranging over 1..M. That is
    inferred from the formula and should be confirmed against the task.

    Only three rows of powers are needed per H: M**k, (M-H+1)**k and
    (M-H)**k. They are kept as rolling 1-D arrays, so memory is O(N)
    instead of an (M+1) x (N+1) table. Time stays O(N * M).
    """
    # M**k: weight of the unconstrained positions outside the run.
    pow_m = power_table(M, N)
    # (M - H + 1)**k for the current H (values >= H); starts at H = 1.
    # Never mutated, so aliasing pow_m here is safe.
    pow_ge = pow_m
    ans = 0
    for H in range(1, M + 1):
        # (M - H)**k: runs whose values are all > H.
        pow_gt = power_table(M - H, N)
        for size in range(1, N + 1):
            # Runs of this length with all values >= H and at least one == H.
            # The difference may be negative; the Python-style `%` below
            # (which numba follows for integers) normalises it.
            x = pow_ge[size] - pow_gt[size]
            if size == N:
                # The run is the whole array: no neighbours to constrain.
                y = 1
            else:
                # The run touches the left or the right end. Exactly one
                # neighbour must be below H (H - 1 choices) and the other
                # N - size - 1 positions are free.
                y = 2 * (H - 1) * pow_m[N - size - 1] % MOD
                if size <= N - 2:
                    # Interior run: both neighbours are below H. There are
                    # N - 1 - size such placements.
                    z = (H - 1) * (H - 1) * pow_m[N - size - 2] % MOD
                    y += z * (N - 1 - size) % MOD
            # Each term is < MOD and there are <= N * M of them, so the
            # int64 accumulator cannot overflow at the task's limits.
            ans += x * y % MOD
        # (M - H)**k is (M - (H + 1) + 1)**k, the ">= H + 1" row for the next H.
        pow_ge = pow_gt
    print(ans % MOD)

# Script entry: stdin holds exactly two integers, N and M.
N, M = (int(token) for token in read().split())
main(N, M)

Submission Info

Submission Time
Task C - Sequence Scores
User maspy
Language Python (3.8.2)
Score 600
Code Size 2113 Byte
Status AC
Exec Time 1034 ms
Memory 300992 KB

Judge Result

Set Name            Sample    All
Score / Max Score   0 / 0     600 / 600
Status              AC × 3    AC × 22
Set Name Test Cases
Sample s1.txt, s2.txt, s3.txt
All 01.txt, 02.txt, 03.txt, 04.txt, 05.txt, 06.txt, 07.txt, 08.txt, 09.txt, 10.txt, 11.txt, 12.txt, 13.txt, 14.txt, 15.txt, 16.txt, 17.txt, 18.txt, 19.txt, s1.txt, s2.txt, s3.txt
Case Name Status Exec Time Memory
01.txt AC 504 ms 107376 KB
02.txt AC 480 ms 106440 KB
03.txt AC 517 ms 119656 KB
04.txt AC 571 ms 139676 KB
05.txt AC 492 ms 106456 KB
06.txt AC 588 ms 140172 KB
07.txt AC 665 ms 166824 KB
08.txt AC 684 ms 178888 KB
09.txt AC 522 ms 120700 KB
10.txt AC 604 ms 150316 KB
11.txt AC 670 ms 176072 KB
12.txt AC 781 ms 216320 KB
13.txt AC 592 ms 145884 KB
14.txt AC 703 ms 185696 KB
15.txt AC 737 ms 200240 KB
16.txt AC 852 ms 242472 KB
17.txt AC 478 ms 106680 KB
18.txt AC 1034 ms 300992 KB
19.txt AC 1033 ms 300472 KB
s1.txt AC 498 ms 105412 KB
s2.txt AC 480 ms 106684 KB
s3.txt AC 480 ms 105984 KB