提出 #20934340


ソースコード 拡げる

import sys
import numpy as np
import numba
from numba import njit, b1, i1, i4, i8, f8

read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline
readlines = sys.stdin.buffer.readlines

MOD = 998_244_353

def from_read(dtype=np.int64):
    return np.fromstring(read().decode(), dtype=dtype, sep=' ')


def from_readline(dtype=np.int64):
    return np.fromstring(readline().decode(), dtype=dtype, sep=' ')

@njit
def mpow(a, n):
    p = 1
    while n:
        if n & 1:
            p = p * a % MOD
        a = a * a % MOD
        n >>= 1
    return p

@njit
def power_table(a, N):
    A = np.ones(N + 1, np.int64)
    for n in range(N):
        A[n + 1] = A[n] * a % MOD
    return A

@njit((i8, i8), cache=True)
def main(N, M):
    POWER = np.zeros((M + 1, N + 1), np.int64)
    for x in range(M + 1):
        POWER[x] = power_table(x, N)

    ans = 0
    """for L in range(0,N):
        for R in range(L+1,N+1):
            for H in range(1, M+1):
                x = POWER[M-H+1,R-L]
                x -= POWER[M-H,R-L]
                if L == 0 and R == N:
                    y = 1
                elif L == 0 and R < N:
                    y = (H-1) * POWER[M,N-(R-L)-1] % MOD
                elif 0 < N and R == N:
                    y = (H-1) * POWER[M,N-(R-L)-1] % MOD
                else:
                    y = (H-1) * (H-1) * POWER[M,N-(R-L)-2] % MOD
                x = x * y % MOD
                ans += x"""
    for H in range(1, M + 1):
        for size in range(1, N + 1):
            x = POWER[M - H + 1, size] - POWER[M - H, size]
            if size == N:
                y = 1
            else:
                y = 0
                y += (H - 1) * POWER[M, N - size - 1] % MOD
                y += (H - 1) * POWER[M, N - size - 1] % MOD
                if size <= N - 2:
                    z = (H - 1) * (H - 1) * POWER[M, N - size - 2] % MOD
                    y += z * (N - 1 - size) % MOD
            x = x * y % MOD
            ans += x
    print(ans % MOD)

a, b = map(int, read().split())

main(a, b)

提出情報

提出日時
問題 C - Sequence Scores
ユーザ maspy
言語 Python (3.8.2)
得点 600
コード長 2113 Byte
結果 AC
実行時間 1034 ms
メモリ 300992 KiB

ジャッジ結果

セット名 Sample All
得点 / 配点 0 / 0 600 / 600
結果
AC × 3
AC × 22
セット名 テストケース
Sample s1.txt, s2.txt, s3.txt
All 01.txt, 02.txt, 03.txt, 04.txt, 05.txt, 06.txt, 07.txt, 08.txt, 09.txt, 10.txt, 11.txt, 12.txt, 13.txt, 14.txt, 15.txt, 16.txt, 17.txt, 18.txt, 19.txt, s1.txt, s2.txt, s3.txt
ケース名 結果 実行時間 メモリ
01.txt AC 504 ms 107376 KiB
02.txt AC 480 ms 106440 KiB
03.txt AC 517 ms 119656 KiB
04.txt AC 571 ms 139676 KiB
05.txt AC 492 ms 106456 KiB
06.txt AC 588 ms 140172 KiB
07.txt AC 665 ms 166824 KiB
08.txt AC 684 ms 178888 KiB
09.txt AC 522 ms 120700 KiB
10.txt AC 604 ms 150316 KiB
11.txt AC 670 ms 176072 KiB
12.txt AC 781 ms 216320 KiB
13.txt AC 592 ms 145884 KiB
14.txt AC 703 ms 185696 KiB
15.txt AC 737 ms 200240 KiB
16.txt AC 852 ms 242472 KiB
17.txt AC 478 ms 106680 KiB
18.txt AC 1034 ms 300992 KiB
19.txt AC 1033 ms 300472 KiB
s1.txt AC 498 ms 105412 KiB
s2.txt AC 480 ms 106684 KiB
s3.txt AC 480 ms 105984 KiB