Submission #17757495


Source Code Expand

Copy
import sys
import numpy as np
import numba
from numba import njit, b1, i4, i8, f8

# Bind the binary stdin reader methods to short module-level names.
_stdin = sys.stdin.buffer
read = _stdin.read
readline = _stdin.readline
readlines = _stdin.readlines

def from_read(dtype=np.int64):
    """Parse everything left on stdin as whitespace-separated numbers.

    Returns a 1-D NumPy array of the requested ``dtype``.
    """
    text = read().decode()
    return np.fromstring(text, dtype=dtype, sep=' ')


def from_readline(dtype=np.int64):
    """Parse the next stdin line as whitespace-separated numbers.

    Returns a 1-D NumPy array of the requested ``dtype``.
    """
    line = readline().decode()
    return np.fromstring(line, dtype=dtype, sep=' ')

"""def naive(col, n):
    def mex(a,b):
        for i in range(3):
            if a == i or b == i:
                continue
            return i
    dp = np.zeros_like(col)
    x = n
    for i in range(len(col)):
        dp[i] = mex(col[i], x)
        x = dp[i]
    return dp
col = np.random.randint(0, 3, 30)
for _ in range(10):
    print(col)
    x = np.random.randint(0, 3)
    col = naive(col, x)"""

@njit((i8, i8), cache=True)
def mex(a, b):
    """Return the smallest value among 0, 1, 2 that equals neither a nor b.

    Falls through without an explicit return if all three candidates are
    excluded, which cannot happen when only two values are excluded.
    """
    for candidate in range(3):
        if candidate != a and candidate != b:
            return candidate

@njit
def update_row(row, x):
    """Build the next row from the row above it.

    ``row`` supplies the value above each new cell and ``x`` is the value to
    the left of the first new cell. Every new cell is ``mex(above, left)``,
    where ``left`` is the cell written just before it.
    """
    nxt = np.zeros_like(row)
    left = x
    for j, above in enumerate(row):
        nxt[j] = mex(above, left)
        # The cell just written is the left neighbour of the next one.
        left = nxt[j]
    return nxt

@njit
def solve_small(N, A):
    """Count the values 0, 1, 2 by generating every row explicitly.

    ``A[0]`` is the corner cell, ``A[1:N]`` the rest of the first row and
    ``A[N:]`` the value that seeds the left end of each later row.
    Returns a length-3 int64 array of occurrence counts.
    """
    counts = np.zeros(3, np.int64)
    counts[A[0]] += 1

    top = A[1:N]
    left = A[N:]
    for v in top:
        counts[v] += 1
    for v in left:
        counts[v] += 1

    # Generate the remaining N - 1 rows one at a time and tally each.
    current = top
    for r in range(N - 1):
        current = update_row(current, left[r])
        for v in current:
            counts[v] += 1
    return counts

@njit((i8, i8[:]), cache=True)
def main(N, A):
    """Count how many cells hold 0, 1 and 2 without generating every row.

    ``A[0]`` is the top-left cell, ``A[1:N]`` is used as the rest of the
    first row, and ``A[N:]`` supplies the value that seeds the left end of
    each later row. For ``N <= 3`` the work is delegated to ``solve_small``.

    Otherwise three further rows are generated in full. After that only the
    three leftmost values of each row are recomputed; every other cell is
    accounted for by assuming it repeats the value of its upper-left
    neighbour (per the original author's comments; not verified here).

    Returns a length-3 int64 array whose entry ``v`` is the number of cells
    counted with value ``v``.
    """
    if N <= 3:
        return solve_small(N, A)
    # The (1,1) cell needs no computation; it is tallied directly.
    ans = np.zeros(3, np.int64)
    ans[A[0]] += 1
    row = A[1:N]
    col = A[N:]
    for x in row:
        ans[x] += 1
    for x in col:
        ans[x] += 1
    # Generate the next three rows in full and tally every cell.
    for i in range(3):
        row = update_row(row, col[i])
        for x in row:
            ans[x] += 1
    col = col[3:]
    K = len(col)
    # K rows are still left to account for.
    for i in range(2, N - 1):
        # The value currently at row[i] is assumed to move unchanged to the
        # lower right, one step per remaining row.
        # It is counted once per step: at most (N - 2) - i steps before it
        # passes index N - 2 (the last index of `row` when len(row) == N - 1;
        # TODO confirm len(A) == 2 * N - 1), and at most K steps before the
        # rows run out.
        k = (N - 2) - i
        k = min(k, K)
        ans[row[i]] += k
    # From here on only the three leftmost values of each row are computed.
    row = row[:3]
    for i in range(K):
        row = update_row(row, col[i])
        a, b, c = row
        ans[a] += 1
        ans[b] += 1
        # c is credited K - i times: once for this row and, presumably, once
        # for each of the K - i - 1 rows below as it continues down-right.
        ans[c] += K - i
    return ans

# The first input line is N; everything after it is parsed into A.
N = int(readline())
A = from_read()

# Print the three counts returned by main, separated by spaces.
counts = main(N, A)
print(*counts)

Submission Info

Submission Time
Task E - Mex Mat
User maspy
Language Python (3.8.2)
Score 800
Code Size 2508 Byte
Status AC
Exec Time 557 ms
Memory 124256 KB

Judge Result

Set Name | Score / Max Score | Status
Sample | 0 / 0 | AC × 1
All | 800 / 800 | AC × 31
Set Name Test Cases
Sample example_00
All ex_small_00, ex_small_01, ex_small_02, ex_small_03, ex_small_04, ex_small_05, ex_small_06, ex_small_07, ex_small_08, ex_small_09, example_00, max_random_00, max_random_01, max_random_02, max_random_03, max_random_04, max_random_05, max_random_06, max_random_07, max_random_08, max_random_09, small_00, small_01, small_02, small_03, small_04, small_05, small_06, small_07, small_08, small_09
Case Name Status Exec Time Memory
ex_small_00 AC 499 ms 106764 KB
ex_small_01 AC 476 ms 105672 KB
ex_small_02 AC 479 ms 106124 KB
ex_small_03 AC 485 ms 106076 KB
ex_small_04 AC 478 ms 105596 KB
ex_small_05 AC 480 ms 106116 KB
ex_small_06 AC 478 ms 106784 KB
ex_small_07 AC 478 ms 106908 KB
ex_small_08 AC 477 ms 106124 KB
ex_small_09 AC 479 ms 106756 KB
example_00 AC 477 ms 106108 KB
max_random_00 AC 542 ms 120456 KB
max_random_01 AC 557 ms 123124 KB
max_random_02 AC 485 ms 108012 KB
max_random_03 AC 551 ms 122632 KB
max_random_04 AC 523 ms 116500 KB
max_random_05 AC 534 ms 117664 KB
max_random_06 AC 557 ms 124256 KB
max_random_07 AC 489 ms 108516 KB
max_random_08 AC 521 ms 115920 KB
max_random_09 AC 497 ms 109428 KB
small_00 AC 479 ms 106148 KB
small_01 AC 477 ms 106140 KB
small_02 AC 477 ms 106104 KB
small_03 AC 479 ms 106012 KB
small_04 AC 477 ms 105600 KB
small_05 AC 478 ms 106780 KB
small_06 AC 477 ms 106328 KB
small_07 AC 480 ms 106908 KB
small_08 AC 478 ms 105624 KB
small_09 AC 479 ms 105676 KB