提出 #16874853


ソースコード 拡げる

import sys
import numpy as np
import numba
from numba import njit, b1, i4, i8, f8

read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline
readlines = sys.stdin.buffer.readlines

@njit((numba.types.optional(i8), ) * 2, cache=True)
def seg_f(x, y):
    """Combine two segment values by taking their minimum.

    ``None`` acts as the identity element: when one argument is ``None``
    the other one is returned unchanged (so two ``None`` inputs give
    ``None``).
    """
    if x is not None:
        if y is not None:
            # Both operands are present: the monoid operation is min.
            return min(x, y)
        return x
    return y


@njit((i8[:], i8[:]), cache=True)
def build(seg, raw_data):
    """Initialise the segment tree ``seg`` in place from ``raw_data``.

    The second half of ``seg`` holds the leaves; node ``k`` has children
    ``2k`` and ``2k + 1``.  Index 0 is never written by this function.
    """
    half = len(seg) // 2
    seg[half:] = raw_data
    # Fill internal nodes bottom-up so both children are ready when read.
    node = half - 1
    while node > 0:
        left = 2 * node
        seg[node] = seg_f(seg[left], seg[left + 1])
        node -= 1


@njit((i8[:], i8, i8), cache=True)
def set_val(seg, i, x):
    """Store ``x`` at leaf ``i`` and refresh every ancestor of that leaf.

    ``i`` is a 0-based leaf index; the leaf lives at ``i + len(seg) // 2``
    inside ``seg``.  The tree is modified in place.
    """
    pos = i + len(seg) // 2
    seg[pos] = x
    # Walk towards the root (node 1), recomputing each parent from its
    # two children.
    parent = pos >> 1
    while parent >= 1:
        seg[parent] = seg_f(seg[2 * parent], seg[2 * parent + 1])
        parent >>= 1


@njit((i8[:], i8, i8), cache=True)
def fold(seg, l, r):
    """Combine the leaves in the half-open range ``[l, r)`` with ``seg_f``.

    ``l`` and ``r`` are 0-based leaf indices.  The result is the minimum of
    the selected leaves, or ``None`` when the range is empty.
    """
    half = len(seg) // 2
    lo = l + half
    hi = r + half
    # Separate accumulators for the pieces collected on the left and on
    # the right border; None is the identity for seg_f.
    acc_left = None
    acc_right = None
    while lo < hi:
        if (lo & 1) == 1:
            # lo is a right child: take it and step to the next subtree.
            acc_left = seg_f(acc_left, seg[lo])
            lo += 1
        if (hi & 1) == 1:
            # hi is exclusive: step back onto the node that is included.
            hi -= 1
            acc_right = seg_f(seg[hi], acc_right)
        lo >>= 1
        hi >>= 1
    return seg_f(acc_left, acc_right)

@njit((i8, i8[:]), cache=True)
def main(N, TX):
    """Process every query and return ``(N - 2) ** 2`` minus the running total.

    Parameters
    ----------
    N : int
        Board size; the value ``N - 2`` is used as the default everywhere.
    TX : int64 array
        Flat sequence ``t0, x0, t1, x1, ...`` of queries.  ``t`` is 1 or 2
        and ``x`` is shifted by ``-2`` to obtain a 0-based index.

    For a query of type 1 the value ``n = min(seg2[x:N])`` is added to the
    total (presumably the number of stones flipped by that move -- confirm
    against the problem statement); if ``n`` is non-zero, ``dp1[n - 1]`` is
    lowered to ``x`` and mirrored into ``seg1``.  Type 2 is symmetric with
    the roles of the two trees swapped.

    A query whose ``(t, x)`` pair was already processed is skipped.  The
    ``filled`` arrays are now actually written; previously they were only
    read, so the guard could never trigger and a repeated query would have
    been counted twice.
    """
    dp1 = np.full(N, N - 2, np.int64)
    dp2 = np.full(N, N - 2, np.int64)
    # Every node starts at N - 2, which is already a consistent min-tree
    # over the all-default leaves, so no explicit build step is required.
    seg1 = np.full(N + N, N - 2, np.int64)
    seg2 = np.full(N + N, N - 2, np.int64)
    # filledK[x] != 0 once a type-K query at index x has been handled.
    filled1 = np.zeros(N, np.int64)
    filled2 = np.zeros(N, np.int64)

    ans = 0
    for i in range(0, len(TX), 2):
        t, x = TX[i:i + 2]
        x -= 2
        if t == 1:
            if filled1[x]:
                continue
            filled1[x] = 1
            n = fold(seg2, x, N)
            ans += n
            if n:
                dp1[n - 1] = min(dp1[n - 1], x)
                set_val(seg1, n - 1, dp1[n - 1])
        elif t == 2:
            if filled2[x]:
                continue
            filled2[x] = 1
            n = fold(seg1, x, N)
            ans += n
            if n:
                dp2[n - 1] = min(dp2[n - 1], x)
                set_val(seg2, n - 1, dp2[n - 1])
    return (N - 2) * (N - 2) - ans

# First line: board size N and query count Q (Q is read but not used below).
N, Q = (int(token) for token in readline().split())
# Remaining input: all "t x" query tokens flattened into one int64 array.
TX = np.array(read().split(), dtype=np.int64)

print(main(N, TX))

提出情報

提出日時
問題 F - Simplified Reversi
ユーザ maspy
言語 Python (3.8.2)
得点 600
コード長 2154 Byte
結果 AC
実行時間 593 ms
メモリ 124936 KiB

ジャッジ結果

セット名 Sample All
得点 / 配点 0 / 0 600 / 600
結果
AC × 3
AC × 21
セット名 テストケース
Sample sample_01.txt, sample_02.txt, sample_03.txt
All hand_01.txt, hand_02.txt, hand_03.txt, random_01.txt, random_02.txt, random_03.txt, random_04.txt, random_05.txt, random_06.txt, random_07.txt, random_08.txt, random_09.txt, random_10.txt, random_11.txt, random_12.txt, random_13.txt, random_14.txt, random_15.txt, sample_01.txt, sample_02.txt, sample_03.txt
ケース名 結果 実行時間 メモリ
hand_01.txt AC 504 ms 106432 KiB
hand_02.txt AC 494 ms 106560 KiB
hand_03.txt AC 483 ms 106344 KiB
random_01.txt AC 593 ms 124236 KiB
random_02.txt AC 504 ms 119036 KiB
random_03.txt AC 585 ms 123104 KiB
random_04.txt AC 529 ms 116828 KiB
random_05.txt AC 581 ms 124936 KiB
random_06.txt AC 561 ms 121176 KiB
random_07.txt AC 589 ms 124184 KiB
random_08.txt AC 507 ms 114412 KiB
random_09.txt AC 586 ms 124192 KiB
random_10.txt AC 481 ms 107092 KiB
random_11.txt AC 554 ms 122692 KiB
random_12.txt AC 539 ms 117520 KiB
random_13.txt AC 573 ms 124088 KiB
random_14.txt AC 561 ms 123620 KiB
random_15.txt AC 561 ms 123412 KiB
sample_01.txt AC 484 ms 106336 KiB
sample_02.txt AC 497 ms 119316 KiB
sample_03.txt AC 501 ms 118052 KiB