Submission #20952833


Source Code Expand

import sys
import numpy as np
import numba
from numba import njit, b1, i1, i4, i8, f8

# Binary-mode stdin readers (return bytes, no text decoding), bound once at import.
read, readline, readlines = (
    sys.stdin.buffer.read,
    sys.stdin.buffer.readline,
    sys.stdin.buffer.readlines,
)

def from_read(dtype=np.int64):
    """Read the rest of stdin and parse it into a 1-D NumPy array of ``dtype``.

    Numbers are separated by spaces; numpy's text mode also skips any extra
    whitespace (including newlines) between elements.
    """
    text = read().decode()
    return np.fromstring(text, sep=' ', dtype=dtype)


def from_readline(dtype=np.int64):
    """Read one line of stdin and parse it into a 1-D NumPy array of ``dtype``.

    Numbers on the line are separated by spaces.
    """
    line = readline().decode()
    return np.fromstring(line, sep=' ', dtype=dtype)

MOD = 998244353  # prime modulus used for all modular arithmetic below

@njit
def mpow(a, n):
    """Return ``a ** n % MOD`` by square-and-multiply.

    ``n`` is expected to be a non-negative integer.
    """
    result = 1
    base = a
    exp = n
    while exp != 0:
        # exp % 2 and exp // 2 match exp & 1 and exp >> 1 for every int.
        if exp % 2 == 1:
            result = result * base % MOD
        base = base * base % MOD
        exp //= 2
    return result


@njit
def fact_table(N=1 << 20):
    """Build modular factorial tables of length ``N + 1``.

    Returns ``(fact, fact_inv, inv)`` as int64 arrays where
    ``fact[i] = i! % MOD``, ``fact_inv[i]`` is the modular inverse of ``i!``
    (via Fermat, since MOD is prime), and ``inv[i] = (i-1)! * fact_inv[i]``,
    i.e. the inverse of ``i``; ``inv[0]`` is a placeholder 0.
    """
    size = N + 1
    fact = np.empty(size, np.int64)
    fact_inv = np.empty(size, np.int64)
    inv = np.empty(size, np.int64)

    fact[0] = 1
    for i in range(1, size):
        fact[i] = i * fact[i - 1] % MOD

    # Invert only the largest factorial, then walk down: (i-1)!^-1 = i!^-1 * i.
    fact_inv[size - 1] = mpow(fact[size - 1], MOD - 2)
    for i in range(size - 1, 0, -1):
        fact_inv[i - 1] = fact_inv[i] * i % MOD

    # 1/i = (i-1)! / i!
    inv[0] = 0
    inv[1:] = fact[:size - 1] * fact_inv[1:] % MOD
    return fact, fact_inv, inv

@njit((i8, ) * 6, cache=True)
def main(H, W, x1, y1, x2, y2):
    """Compute the answer for an H x W grid and the rectangle (x1, y1)-(x2, y2).

    The corner cells are first normalised so that x1 <= x2 and y1 <= y2. The
    four margins around the rectangle have widths x1 - 1, H - x2, y1 - 1 and
    W - y2. Starting from the monomial x^n0 with
    n0 = (x2 - x1 + 2) + (y2 - y1 + 2), a bivariate polynomial f is multiplied,
    once per margin of width N, by

        (1 - x^N) / (1 - x) + x^N * y  =  1 + x + ... + x^(N-1) + x^N * y.

    ``f[k, n]`` holds the coefficient of y^k x^n. The result is the sum of
    f[k, n] * P(n, k) reduced modulo MOD, where
    P(n, k) = (4 - k)! * (n - 4)! / (n - k)!.
    """
    if x1 > x2:
        x1, x2 = x2, x1
    if y1 > y2:
        y1, y2 = y2, y1

    n0 = x2 - x1 + 2 + y2 - y1 + 2
    f = np.zeros((1, n0 + 1), np.int64)
    f[0, n0] = 1

    for N in [x1 - 1, H - x2, y1 - 1, W - y2]:
        rows, cols = f.shape
        h = np.zeros((rows + 1, cols + N + 1), np.int64)
        for i in range(rows):
            # Multiply by (1 - x) * x^N * y: the y factor moves row i to i + 1.
            h[i + 1, N:cols + N] += f[i]
            h[i + 1, N + 1:cols + N + 1] -= f[i]
            # Multiply by (1 - x^N), staying in row i.
            h[i, :cols] += f[i]
            h[i, N:N + cols] -= f[i]
        # Divide by (1 - x): a prefix sum along the x axis of every row.
        for i in range(rows + 1):
            h[i] = np.cumsum(h[i])
        # The quotient has x-degree cols - 1 + N, so the last column is zero.
        f = h[:, :-1] % MOD

    # Size the factorial tables to the largest x-exponent actually used
    # (f.shape[1] - 1) instead of the fixed default of fact_table(), which
    # would be indexed out of bounds for very large H + W.
    fact, fact_inv, _inv = fact_table(f.shape[1])

    def P(n, k):
        # (4 - k)! * (n - 4)! / (n - k)!  (mod MOD); requires n >= 4, k <= 4.
        return fact[4 - k] * fact[n - 4] % MOD * fact_inv[n - k] % MOD

    K, cols = f.shape
    ans = 0
    # f starts at x^n0 with n0 >= 4 and is only multiplied by polynomials with
    # non-negative x-powers, so f[k, n] == 0 for n < 4. Skipping those columns
    # keeps every table index in P non-negative (no reliance on wraparound).
    for k in range(K):
        for n in range(4, cols):
            ans = (ans + P(n, k) * f[k, n]) % MOD
    return ans

# Script entry: all whitespace-separated integers on stdin become main's arguments.
print(main(*[int(token) for token in read().split()]))

Submission Info

Submission Time
Task E - Paper Cutting 2
User maspy
Language Python (3.8.2)
Score 700
Code Size 2463 Byte
Status AC
Exec Time 567 ms
Memory 144748 KiB

Judge Result

Set Name          | Sample | All
Score / Max Score | 0 / 0  | 700 / 700
Status            | AC × 4 | AC × 21
Set Name | Test Cases
Sample   | s1.txt, s2.txt, s3.txt, s4.txt
All      | 01.txt, 02.txt, 03.txt, 04.txt, 05.txt, 06.txt, 07.txt, 08.txt, 09.txt, 10.txt, 11.txt, 12.txt, 13.txt, 14.txt, 15.txt, 16.txt, 17.txt, s1.txt, s2.txt, s3.txt, s4.txt
Case Name Status Exec Time Memory
01.txt AC 536 ms 139484 KiB
02.txt AC 540 ms 139892 KiB
03.txt AC 512 ms 140740 KiB
04.txt AC 519 ms 139320 KiB
05.txt AC 513 ms 139736 KiB
06.txt AC 512 ms 139728 KiB
07.txt AC 543 ms 140180 KiB
08.txt AC 549 ms 144440 KiB
09.txt AC 547 ms 143652 KiB
10.txt AC 538 ms 141940 KiB
11.txt AC 531 ms 139160 KiB
12.txt AC 546 ms 140120 KiB
13.txt AC 534 ms 140660 KiB
14.txt AC 537 ms 139812 KiB
15.txt AC 551 ms 143580 KiB
16.txt AC 549 ms 144316 KiB
17.txt AC 567 ms 144748 KiB
s1.txt AC 527 ms 139724 KiB
s2.txt AC 516 ms 140108 KiB
s3.txt AC 512 ms 139980 KiB
s4.txt AC 512 ms 139296 KiB