提出 #23299615


ソースコード

import sys
import numpy as np
import numba
from numba import njit, b1, i1, i4, i8, f8

# Byte-level stdin readers, bound once so callers skip the attribute lookups.
read, readline, readlines = (
    sys.stdin.buffer.read,
    sys.stdin.buffer.readline,
    sys.stdin.buffer.readlines,
)

def from_read(dtype=np.int64):
    """Parse everything left on stdin as whitespace-separated numbers.

    Returns a 1-D NumPy array of ``dtype``.
    """
    text = read().decode()
    return np.fromstring(text, dtype=dtype, sep=' ')


def from_readline(dtype=np.int64):
    """Parse a single stdin line as whitespace-separated numbers.

    Returns a 1-D NumPy array of ``dtype``.
    """
    line = readline().decode()
    return np.fromstring(line, dtype=dtype, sep=' ')

@njit
def to_undirected(G, add_index=False):
    """Return the edge list with every edge present in both orientations.

    ``G`` is an edge list whose columns 0 and 1 are endpoints. When
    ``add_index`` is true, the original row index is first appended as an
    extra column, so both copies of an edge carry the same index.

    The result is stably sorted by (column 0, column 1). Column 0 is packed
    into the high 32 bits of the sort key, so column 1 is assumed to be
    below 2**32.
    """
    m = len(G)
    if add_index:
        G = np.append(G, np.arange(m).reshape(m, 1), axis=1)
    both = np.vstack((G, G))
    # Second half: swap the two endpoint columns; any extra column is kept.
    both[m:, 0] = both[:m, 1]
    both[m:, 1] = both[:m, 0]
    order = np.argsort(both[:, 0] << 32 | both[:, 1], kind='mergesort')
    return both[order]

@njit
def construct_graph(query):
    """Build the final tree from the queries and normalise the queries in place.

    Vertex 0 exists from the start. Each type-1 row (1, a, c) creates the
    next vertex v as a child of ``a`` with edge weight ``c``, and the row is
    rewritten to (1, v, c).

    Each type-2 row (2, a, c) sets the weight of the edge (parent[a], a) to
    ``c``. That row is rewritten to carry the change ``c - old weight``
    instead. Other rows are left untouched.

    Returns (N, _TO, idx, parent):
    - N: the number of vertices.
    - _TO, idx: a CSR adjacency; the neighbours of v are
      ``_TO[idx[v]:idx[v + 1]]``.
    - parent: for v < N, the parent of v (-1 for vertex 0).
    """
    Q = len(query)
    # Vertex 0 plus at most one new vertex per query: up to Q + 1 vertices.
    # Sizing these arrays Q would overflow when every query is type 1 (or
    # when Q == 0), and @njit does not bounds-check by default.
    C = np.zeros(Q + 1, np.int64)  # C[v]: current weight of edge (parent[v], v)
    parent = np.zeros(Q + 1, np.int64)
    G, g = np.empty((Q, 2), np.int64), 0
    parent[0] = -1
    n = 1
    for q in range(Q):
        t, a, c = query[q]
        if t == 1:
            v = n
            n += 1
            parent[v] = a
            G[g], g = (a, v), g + 1
            C[v] = c
            query[q] = (t, v, c)
        elif t == 2:
            # Store the delta so callers can apply updates additively.
            query[q] = (t, a, c - C[a])
            C[a] = c
    N = n
    G = to_undirected(G[:g])
    idx = np.searchsorted(G[:, 0], np.arange(N + 1))
    _TO = G[:, 1].copy()
    return N, _TO, idx, parent

INF = 2 ** 60  # "minus infinity" sentinel magnitude for the prefix-max identity


@njit
def add_range(seg, l, r, v):
    """Add ``v`` to the point values at positions [l, r).

    The leaves of ``seg`` (``len(seg) // 2`` of them) hold a difference
    array, so a range add becomes two point updates: +v at leaf l and -v at
    leaf r. The second update is skipped when r equals the leaf count.

    After each point update, every ancestor is rebuilt as the pair
    (sum, max prefix sum) of its two children.
    """
    half = len(seg) // 2

    def point_add(pos, delta):
        # Shift both the sum column and the max-prefix column of the leaf.
        seg[pos] += delta
        node = pos >> 1
        while node >= 1:
            lc = node << 1
            rc = lc | 1
            seg[node, 0] = seg[lc, 0] + seg[rc, 0]
            seg[node, 1] = max(seg[lc, 1], seg[lc, 0] + seg[rc, 1])
            node >>= 1

    point_add(l + half, v)
    if r + half != len(seg):
        point_add(r + half, -v)


@njit
def get_range(seg, l, r):
    """Return max over l <= p < r of (D[0] + D[1] + ... + D[p]).

    ``seg`` is a bottom-up segment tree with ``len(seg) // 2`` leaves. Leaf
    values form a difference array D, and each node stores (sum, max prefix
    sum). The return value is therefore the maximum point value over
    positions [l, r). Requires l < r.
    """
    assert l < r
    N = len(seg) // 2
    l, r = l + N, r + N
    l0 = l
    # (sum, max prefix sum) accumulators for the left and right edges of the
    # range. They are kept separate because the combine is not commutative.
    vl = (0, -INF)
    vr = (0, -INF)
    while l < r:
        if l & 1:
            # Append node l to the right end of vl.
            s = vl[0] + seg[l, 0]
            x = max(vl[1], vl[0] + seg[l, 1])
            vl = (s, x)
            l += 1
        if r & 1:
            r -= 1
            # Prepend node r to the left end of vr.
            s = seg[r, 0] + vr[0]
            x = max(seg[r, 1], seg[r, 0] + vr[1])
            vr = (s, x)
        l, r = l >> 1, r >> 1
    # Max prefix sum of D[l:r], measured from position l.
    v = max(vl[1], vl[0] + vr[1])
    # Add sum(D[0:l]) so the prefix sums start at index 0.
    l, r = N, l0
    while l < r:
        if l & 1:
            v += seg[l, 0]
            l += 1
        if r & 1:
            r -= 1
            v += seg[r, 0]
        l, r = l >> 1, r >> 1
    return v

@njit((i8[:, :], ), cache=True)
def main(query):
    """Answer every type-3 query offline with a centroid decomposition.

    Each row of ``query`` is (t, a, c, q_id), where q_id is the row's
    original index. construct_graph rewrites the first three columns in
    place:
    - a type-1 row's ``a`` becomes the vertex it creates;
    - a type-2 row's ``c`` becomes the change in the weight of the edge
      (PAR[a], a).

    PAR is the parent in the original tree. It is used to find the other
    endpoint of the edge an update row touches.

    Returns ANS restricted to the type-3 rows, in input order.
    """
    N, _TO, idx, PAR = construct_graph(query[:, :3])
    # ng[v]: v has already served as a centroid and is excluded from every
    # later (smaller) component.
    ng = np.zeros(N, np.bool_)

    # Neighbours of v in the CSR adjacency returned by construct_graph.
    def TO(v):
        return _TO[idx[v]:idx[v + 1]]

    # Scratch arrays shared by all components; each tour overwrites the
    # entries of the vertices it visits.
    par = np.zeros(N, np.int64)
    root = np.zeros(N, np.int64)
    vis = np.zeros(N, np.bool_)
    size = np.ones(N, np.int64)
    LID = np.zeros(N, np.int64)
    RID = np.zeros(N, np.int64)
    ANS = np.zeros(len(query), np.int64)

    # Iterative DFS from v0 over the n vertices of the current component
    # (ng vertices are skipped).
    # Records:
    # - the preorder V;
    # - par, the DFS parent of each vertex;
    # - LID/RID, such that the subtree of v occupies V[LID[v]:RID[v]].
    # Returns V and a centroid: the first vertex, in reverse preorder, whose
    # subtree holds at least ceil(n/2) vertices. vis and size are reset
    # before returning.
    def euler_tour(n, v0):
        s = 0
        st = np.empty(n, np.int64)
        V, _n = np.empty(n, np.int64), 0
        st[s], s = v0, s + 1
        vis[v0] = 1
        par[v0] = -1
        while s:
            v, s = st[s - 1], s - 1
            V[_n], _n = v, _n + 1
            for w in TO(v):
                if ng[w] or vis[w]:
                    continue
                vis[w] = 1
                par[w] = v
                st[s], s = w, s + 1
        assert _n == n
        center = -1
        # Reverse preorder visits children before parents, so size[v] is
        # complete by the time v is examined.
        for v in V[::-1]:
            if center == -1 and size[v] >= (n + 1) // 2:
                center = v
            if v != v0:
                size[par[v]] += size[v]
        assert center != -1
        for i in range(n):
            v = V[i]
            LID[v] = i
            RID[v] = i + size[v]
        for v in V:
            vis[v] = 0
            size[v] = 1
        return V, center

    # Find the component's centroid, then redo the tour rooted at it, so that
    # V[0] is the centroid and LID/RID describe the centroid-rooted tree.
    # root[u] is set to the centroid's neighbour whose branch contains u
    # (root[V[0]] itself is not set).
    def centroid_decmop(n, v0):
        V, center = euler_tour(n, v0)
        V, _ = euler_tour(n, center)
        for w in TO(V[0]):
            if ng[w]:
                continue
            for i in range(LID[w], RID[w]):
                root[V[i]] = w
        return V

    # Replay this component's queries in time order on a segment tree
    # indexed by preorder position. The point value at LID[w] is the sum of
    # all deltas applied so far to the edges on the path centroid -> w.
    def solve(V, query):
        N = len(V)
        Q = len(query)
        seg = np.zeros((N + N, 2), np.int64)
        for q in range(Q):
            t, a, c, q_id = query[q]

            if t == 1 or t == 2:
                b = PAR[a]
                # Apply the weight change c to the edge between a and b,
                # but only if both endpoints lie in this component (LID may
                # be stale from an earlier component, hence the V check).
                if LID[a] >= N or V[LID[a]] != a:
                    continue
                if LID[b] >= N or V[LID[b]] != b:
                    continue
                # Make a the endpoint farther from the centroid (later in
                # preorder).
                if LID[a] < LID[b]:
                    a, b = b, a
                # Equivalent to seg[LID[a]:RID[a]] += c on the point values.
                add_range(seg, LID[a], RID[a], c)
            else:
                # Type-3 row: farthest distance from a along paths that pass
                # through the centroid.
                if a == V[0]:
                    # a is the centroid: maximum over the whole component.
                    x = get_range(seg, 0, N)
                else:
                    # Distance centroid -> a ...
                    x0 = get_range(seg, LID[a], LID[a] + 1)
                    v = root[a]
                    l, r = LID[v], RID[v]
                    # ... plus the farthest vertex outside a's own branch
                    # [l, r). Position 0 (the centroid, distance 0) is
                    # always included.
                    x1 = get_range(seg, 0, l)
                    if r < N:
                        x1 = max(x1, get_range(seg, r, N))
                    x = x0 + x1
                ANS[q_id] = max(ANS[q_id], x)

    query_all = query
    # Explicit stack of (component size, start vertex, rows touching it).
    stack = [(N, 0, query_all)]
    while stack:
        n, v0, query = stack.pop()
        if n == 1:
            continue
        V = centroid_decmop(n, v0)
        solve(V, query)
        ng[V[0]] = 1

        # Split the rows by the centroid branch that contains their vertex.
        # Rows on the centroid itself (-1) are not forwarded.
        Q = len(query)
        query_rt = np.empty(Q, np.int64)
        for q in range(Q):
            v = query[q, 1]
            if v == V[0]:
                query_rt[q] = -1
            else:
                query_rt[q] = root[v]
        argsort = np.argsort(query_rt, kind='mergesort')  # stable: keeps time order within each branch
        query = query[argsort]
        query_rt = query_rt[argsort]
        for v in TO(V[0]):
            if not ng[v]:
                ql, qr = np.searchsorted(query_rt, [v, v + 1])
                stack.append((RID[v] - LID[v], v, query[ql:qr]))
    # ANS is indexed by q_id; keep only the type-3 rows.
    ANS = ANS[query_all[:, 0] == 3]
    return ANS

# stdin: one leading token (presumably Q, skipped) followed by rows (t, a, c).
rows = from_read()[1:].reshape(-1, 3)
# Append each row's original position as a 4th column (q_id).
query = np.column_stack((rows, np.arange(len(rows))))

ans = main(query)
print('\n'.join(str(x) for x in ans.tolist()))

提出情報

提出日時
問題 J - 仕事をしよう! (Working!)
ユーザ maspy
言語 Python (3.8.2)
得点 160
コード長 7093 Byte
結果 AC
実行時間 2840 ms
メモリ 201312 KiB

ジャッジ結果

セット名 Sample Subtask1 Subtask2
得点 / 配点 0 / 0 5 / 5 155 / 155
結果 AC × 3 AC × 9 AC × 18
セット名 テストケース
Sample 0-sample-01.txt, 0-sample-02.txt, 0-sample-03.txt
Subtask1 0-sample-01.txt, 0-sample-02.txt, 0-sample-03.txt, 1-random-01.txt, 1-random-02.txt, 1-random-03.txt, 1-random-04.txt, 1-random-05.txt, 1-random-06.txt
Subtask2 0-sample-01.txt, 0-sample-02.txt, 0-sample-03.txt, 1-random-01.txt, 1-random-02.txt, 1-random-03.txt, 1-random-04.txt, 1-random-05.txt, 1-random-06.txt, 2-random-01.txt, 2-random-02.txt, 2-random-03.txt, 2-random-04.txt, 2-random-05.txt, 2-random-06.txt, 2-special-01.txt, 2-special-02.txt, 2-special-03.txt
ケース名 結果 実行時間 メモリ
0-sample-01.txt AC 528 ms 108400 KiB
0-sample-02.txt AC 507 ms 109036 KiB
0-sample-03.txt AC 503 ms 108392 KiB
1-random-01.txt AC 507 ms 108312 KiB
1-random-02.txt AC 503 ms 108368 KiB
1-random-03.txt AC 507 ms 109048 KiB
1-random-04.txt AC 529 ms 110600 KiB
1-random-05.txt AC 529 ms 109752 KiB
1-random-06.txt AC 520 ms 109844 KiB
2-random-01.txt AC 738 ms 121500 KiB
2-random-02.txt AC 828 ms 124572 KiB
2-random-03.txt AC 687 ms 124448 KiB
2-random-04.txt AC 2059 ms 174472 KiB
2-random-05.txt AC 2840 ms 188016 KiB
2-random-06.txt AC 1722 ms 184988 KiB
2-special-01.txt AC 2347 ms 180148 KiB
2-special-02.txt AC 998 ms 201312 KiB
2-special-03.txt AC 1218 ms 177272 KiB