Submission #16920107


Source Code Expand

Copy
import sys
import numpy as np
import numba
from numba import njit, b1, i4, i8, f8

# Byte-level stdin readers used by the driver at the bottom of the file.
read, readline, readlines = (
    sys.stdin.buffer.read,
    sys.stdin.buffer.readline,
    sys.stdin.buffer.readlines,
)

# Numba type of the union-find state: a pair of int64 arrays (parent, size).
uf_t = numba.types.UniTuple(i8[:], 2)


@njit((uf_t, i8), cache=True)
def find_root(uf, x):
    """Return the representative (root) of the set containing x.

    Only the parent array uf[0] is used; it is modified in place by path
    halving, where each visited node is re-pointed at its grandparent.
    """
    parent = uf[0]
    p = parent[x]
    while p != x:
        # Skip one level: point x at its grandparent and continue from there.
        grandparent = parent[p]
        parent[x] = grandparent
        x = grandparent
        p = parent[x]
    return x


@njit((uf_t, i8, i8), cache=True)
def merge(uf, x, y):
    """Union the sets containing x and y (union by size).

    Returns True if two distinct sets were joined, False if x and y were
    already in the same set. Both arrays of uf are updated in place.
    """
    parent = uf[0]
    sz = uf[1]
    rx = find_root(uf, x)
    ry = find_root(uf, y)
    if rx == ry:
        return False
    # The smaller tree hangs under the larger; on a tie, x's root stays on top.
    if sz[rx] < sz[ry]:
        big, small = ry, rx
    else:
        big, small = rx, ry
    sz[big] += sz[small]
    # big is a root (parent[big] == big), so this equals parent[small] = parent[big].
    parent[small] = big
    return True

@njit((i8[:], ), cache=True)
def build(raw_data):
    """Build a 1-indexed Fenwick tree from raw_data in O(n).

    raw_data[i] is the initial value at position i for i >= 1. Slot 0 is not
    a tree position and is copied through unchanged. Starting the loop at 0
    would compute lowbit(0) == 0 and add bit[0] to itself.
    The input array is not modified; a new array is returned.
    """
    bit = raw_data.copy()
    n = len(bit)
    for i in range(1, n):
        # Push this node's partial sum up to its Fenwick parent.
        j = i + (i & (-i))
        if j < n:
            bit[j] += bit[i]
    return bit


@njit((i8[:], i8), cache=True)
def get_sum(bit, i):
    """Return the Fenwick-tree prefix sum over positions 1..i (0 for i == 0)."""
    total = 0
    idx = i
    while idx != 0:
        total += bit[idx]
        # Clear the lowest set bit (same as idx -= idx & -idx).
        idx &= idx - 1
    return total


@njit((i8[:], i8, i8), cache=True)
def add(bit, i, x):
    """Add x to position i (1-indexed) of the Fenwick tree bit, in place.

    Raises ValueError when i <= 0. lowbit(0) is 0 and negative indices decay
    to 0, so without the check the update loop would never terminate (and
    would write to wrapped negative slots on the way).
    """
    if i <= 0:
        raise ValueError("Fenwick tree index must be >= 1")
    n = len(bit)
    while i < n:
        bit[i] += x
        # Move to the next node whose range covers position i.
        i += i & -i


@njit((i8[:], i8), cache=True)
def find_kth_element(bit, k):
    """Return the smallest position p >= 1 whose prefix sum is >= k.

    Uses binary lifting over the Fenwick tree. This assumes the underlying
    counts are non-negative, so prefix sums are monotone. If k exceeds the
    total, len(bit) is returned.
    """
    n = len(bit)
    # Largest power of two that is still a valid index (< n).
    step = 1
    while step * 2 < n:
        step <<= 1
    pos = 0
    acc = 0
    while step > 0:
        nxt = pos + step
        # Advance while the prefix sum up to nxt is still short of k.
        if nxt < n and acc + bit[nxt] < k:
            pos = nxt
            acc += bit[nxt]
        step >>= 1
    return pos + 1

@njit((i8[:], ), cache=True)
def main(XY):
    """Print, for each town in input order, the size of its connected group.

    XY is the flattened coordinate list x1, y1, x2, y2, ...; N = len(XY) // 2.
    The x values and the y values are each assumed to be a permutation of
    1..N (XtoY / YtoX are indexed by them). TODO confirm against the task.
    Two towns are treated as directly connected when one is larger than the
    other in both coordinates.

    Algorithm: sweep x upward, keeping a Fenwick tree over the y values that
    are still "active".
    - When an active town (x, y) is reached, every other active town has a
      larger x, so every active town with a larger y is directly connected
      to it and is merged with it.
    - Of that group only the member with the largest x stays active. Any
      later active town directly connected to a retired member is also
      directly connected to that kept member.
    - Every other member is retired. Each step therefore visits
      (retired count + 1) towns, giving O(N log N) overall.
    """
    N = len(XY) // 2
    root = np.arange(N + 1, dtype=np.int64)
    size = np.ones_like(root)
    uf = (root, size)

    XtoY = np.zeros(N + 1, np.int64)
    YtoX = np.zeros(N + 1, np.int64)
    for i in range(0, N + N, 2):
        x, y = XY[i:i + 2]
        XtoY[x] = y
        YtoX[y] = x

    # Fenwick tree over the set of active y values (1 = active).
    bit_raw = np.ones(N + 1, np.int64)
    bit_raw[0] = 0
    bit = build(bit_raw)
    rest_Y = np.ones(N + 1, np.int64)
    rest_Y_cnt = N
    for x in range(1, N + 1):
        y = XtoY[x]
        if not rest_Y[y]:
            # Already retired into an earlier group.
            continue
        # Active y values with rank > k are exactly those above y.
        k = get_sum(bit, y)
        largest_x = x
        for i in range(k + 1, rest_Y_cnt + 1):
            y1 = find_kth_element(bit, i)
            x1 = YtoX[y1]
            merge(uf, x, x1)
            largest_x = max(largest_x, x1)
        # Retire every group member except the one with the largest x.
        # `pos` is the rank under inspection. Retiring an element shifts the
        # next one into rank `pos`, but the kept representative must be
        # stepped over. Always re-querying rank k + 1 would stall on the
        # representative, leave every element ranked after it active, and
        # make the sweep quadratic in the worst case.
        pos = k + 1
        for _ in range(k + 1, rest_Y_cnt + 1):
            y1 = find_kth_element(bit, pos)
            if YtoX[y1] == largest_x:
                pos += 1
            else:
                rest_Y[y1] = 0
                rest_Y_cnt -= 1
                add(bit, y1, -1)
        # The current town is processed; it never needs to be active again.
        rest_Y[y] = 0
        rest_Y_cnt -= 1
        add(bit, y, -1)
    for i in range(0, N + N, 2):
        x = XY[i]
        print(size[find_root(uf, x)])

N = int(readline())  # town count line; main() re-derives N from len(XY)
XY = np.array(read().split(), dtype=np.int64)  # flattened x1 y1 x2 y2 ...

main(XY)

Submission Info

Submission Time
Task A - Reachable Towns
User maspy
Language Python (3.8.2)
Score 300
Code Size 2910 Byte
Status
Exec Time 734 ms
Memory 135068 KB

Judge Result

Set Name           Sample    All
Score / Max Score  0 / 0     300 / 300
Status             AC × 2    AC × 24
Set Name Test Cases
Sample example_00, example_01
All example_00, example_01, manyperm_00, manyperm_01, manyperm_02, manyperm_03, max_random_00, max_random_01, random_00, random_01, small_00, small_01, small_02, small_03, small_04, small_05, small_06, small_07, small_08, small_09, special1_00, special1_01, special1_02, special1_03
Case Name Status Exec Time Memory
example_00 509 ms 106776 KB
example_01 493 ms 106644 KB
manyperm_00 730 ms 131904 KB
manyperm_01 733 ms 135068 KB
manyperm_02 734 ms 132064 KB
manyperm_03 734 ms 132560 KB
max_random_00 732 ms 131816 KB
max_random_01 703 ms 131404 KB
random_00 696 ms 125276 KB
random_01 664 ms 127524 KB
small_00 489 ms 106372 KB
small_01 487 ms 107048 KB
small_02 497 ms 106360 KB
small_03 487 ms 106280 KB
small_04 495 ms 106776 KB
small_05 494 ms 106272 KB
small_06 492 ms 106700 KB
small_07 492 ms 107412 KB
small_08 495 ms 107412 KB
small_09 491 ms 106948 KB
special1_00 618 ms 124512 KB
special1_01 648 ms 128184 KB
special1_02 538 ms 113476 KB
special1_03 660 ms 130252 KB