Submission #19430084


Source Code
import sys
def input():
    """Read one line from stdin with all trailing whitespace removed."""
    return sys.stdin.readline().rstrip()


sys.setrecursionlimit(2 * 10**5 + 10)


def write(x):
    """Write the string ``x`` to stdout followed by a newline."""
    return sys.stdout.write(x + "\n")


def _print(*x):
    """Print debugging output to stderr so stdout stays clean."""
    return print(*x, file=sys.stderr)


# Input: N, then N rows of integers; a nonzero g[i][j] means an edge i -> j.
n = int(input())
g = [list(map(int, input().split())) for _ in range(n)]
# ns[i]: successors of vertex i in increasing order (only the first n columns count).
ns = [[j for j in range(n) if row[j]] for row in g]
def scc(ns):
    """Return the strongly connected components of a directed graph.

    Iterative form of Tarjan's algorithm. ``ns[u]`` lists the successors
    of vertex ``u`` (vertices are ``0 .. len(ns) - 1``). The result is a
    list of components, each a list of vertices. A component is emitted as
    soon as its root vertex finishes. This is Tarjan's usual reverse
    topological order: for an edge between two different components, the
    target's component is listed first.
    """
    size = len(ns)
    comps = []
    low = [size] * size  # lowlink: smallest discovery index reachable
    order = [-1] * size  # discovery index; -1 means not visited yet
    counter = 0
    stack = []  # Tarjan's component stack
    on_stack = [False] * size
    # Successors that were unvisited when their parent was entered
    # (explored later as descendants of that parent).
    tree_succ = [[] for _ in range(size)]
    # Successors that were already on the component stack at that time.
    back_succ = [[] for _ in range(size)]
    for root in range(size):
        if order[root] != -1:
            continue
        # Work stack: a value x >= 0 means "enter x"; ~x means "finish x".
        todo = [root]
        while todo:
            x = todo.pop()
            if x >= 0:
                # Another branch may have reached x after it was pushed.
                if order[x] != -1:
                    continue
                todo.append(~x)  # finish x only after all its successors
                stack.append(x)
                on_stack[x] = True
                order[x] = low[x] = counter
                counter += 1
                for y in ns[x]:
                    if order[y] == -1:
                        todo.append(y)
                        tree_succ[x].append(y)
                    elif on_stack[y]:
                        back_succ[x].append(y)
                continue
            x = ~x
            # Fold in the descendants' lowlinks and the back edges' discovery indices.
            low[x] = min(
                [low[x]]
                + [low[y] for y in tree_succ[x]]
                + [order[y] for y in back_succ[x]]
            )
            if low[x] == order[x]:
                # x is a component root: pop the stack down to and including x.
                comp = []
                y = -1
                while y != x:
                    y = stack.pop()
                    on_stack[y] = False
                    comp.append(y)
                comps.append(comp)
    return comps
def condense(ns, comps):
    """Collapse each component of ``comps`` into a single vertex.

    Args:
        ns: adjacency lists; ``ns[u]`` holds the successors of vertex ``u``.
        comps: a partition of ``range(len(ns))`` into lists of vertices
            (e.g. the strongly connected components returned by ``scc``).

    Returns:
        ``(weights, dag)`` where ``weights[c]`` is the number of vertices in
        ``comps[c]`` and ``dag[c]`` is the sorted list of other components
        that some vertex of ``comps[c]`` has an edge into. When ``comps``
        are the strongly connected components, ``dag`` is acyclic.
    """
    # A vertex missing from comps keeps None here, so using it as an index
    # raises instead of silently landing in some component.
    comp_of = [None] * len(ns)
    for c, members in enumerate(comps):
        for u in members:
            comp_of[u] = c
    succ = [set() for _ in comps]
    for u, nbrs in enumerate(ns):
        cu = comp_of[u]
        for v in nbrs:
            cv = comp_of[v]
            if cu != cv:
                succ[cu].add(cv)
    return [len(members) for members in comps], [sorted(s) for s in succ]


def max_cover_by_paths(weights, dag, num_paths=2):
    """Return the largest total weight of DAG vertices covered by ``num_paths`` paths.

    A vertex visited by several paths is counted once. This is solved as a
    min-cost flow:
      * vertex i is split into "in" node i and "out" node i + m;
      * in(i) -> out(i) has one unit-capacity arc of cost -weights[i] (the
        first visit collects the weight) and one arc of capacity
        num_paths - 1 and cost 0 (later visits are free);
      * a super source feeds every in-node and every out-node drains into
        a super sink;
      * each DAG edge i -> j becomes out(i) -> in(j).
    Sending ``num_paths`` units at minimum cost selects the paths, and the
    answer is the negated cost.

    Args:
        weights: vertex weights; assumed non-negative (the model relies on
            the weighted arc being preferred over the free one).
        dag: ``dag[i]`` lists the successors of vertex i; must be acyclic.
        num_paths: number of paths to choose (the original task uses 2).

    Returns:
        The maximum covered weight; 0 for an empty graph or ``num_paths <= 0``.

    Raises:
        ValueError: if ``dag`` contains a cycle.
    """
    from heapq import heappop, heappush

    m = len(weights)
    if m == 0 or num_paths <= 0:
        return 0
    src, snk = 2 * m, 2 * m + 1
    graph = [[] for _ in range(2 * m + 2)]

    def add_arc(u, v, cap, cost):
        # Arc layout: [head, residual capacity, cost, index of twin in graph[head]].
        graph[u].append([v, cap, cost, len(graph[v])])
        graph[v].append([u, 0, -cost, len(graph[u]) - 1])

    for i, w in enumerate(weights):
        add_arc(src, i, num_paths, 0)
        add_arc(i, i + m, 1, -w)
        add_arc(i, i + m, num_paths - 1, 0)
        add_arc(i + m, snk, num_paths, 0)
        for j in dag[i]:
            add_arc(i + m, j, num_paths, 0)

    # Topological order of the DAG (Kahn's algorithm).
    indeg = [0] * m
    for i in range(m):
        for j in dag[i]:
            indeg[j] += 1
    topo = [i for i in range(m) if indeg[i] == 0]
    k = 0
    while k < len(topo):
        for j in dag[topo[k]]:
            indeg[j] -= 1
            if indeg[j] == 0:
                topo.append(j)
        k += 1
    if len(topo) != m:
        raise ValueError("dag contains a cycle")

    # Initial potentials = exact shortest distances from the source. Costs
    # can be negative, but the flow network is acyclic, so one relaxation
    # pass in topological order suffices.
    inf = float("inf")
    pot = [inf] * len(graph)
    pot[src] = 0
    for u in [src] + [x for i in topo for x in (i, i + m)]:
        for v, cap, cost, _ in graph[u]:
            if cap > 0 and pot[u] + cost < pot[v]:
                pot[v] = pot[u] + cost

    total_cost = 0
    remaining = num_paths
    while remaining > 0:
        # Dijkstra on reduced costs, which are non-negative thanks to the potentials.
        dist = [inf] * len(graph)
        dist[src] = 0
        via = [None] * len(graph)
        heap = [(0, src)]
        while heap:
            d, u = heappop(heap)
            if d > dist[u]:
                continue
            pu = pot[u]
            for idx, (v, cap, cost, _) in enumerate(graph[u]):
                if cap > 0:
                    nd = d + cost + pu - pot[v]
                    if nd < dist[v]:
                        dist[v] = nd
                        via[v] = (u, idx)
                        heappush(heap, (nd, v))
        if dist[snk] == inf:
            break  # not reachable in practice: src -> in -> out -> snk always has capacity
        for v, dv in enumerate(dist):
            if dv < inf:
                pot[v] += dv
        push = remaining
        v = snk
        while v != src:
            u, idx = via[v]
            push = min(push, graph[u][idx][1])
            v = u
        v = snk
        while v != src:
            u, idx = via[v]
            arc = graph[u][idx]
            arc[1] -= push
            graph[v][arc[3]][1] += push
            v = u
        remaining -= push
        # pot[src] stays 0, so pot[snk] is now the true shortest-path cost.
        total_cost += push * pot[snk]
    return -total_cost


# The judge's PyPy3 run ended in RE on every case; presumably networkx is
# not installed there, so the flow is solved by the helpers above instead.
# The guard keeps the helpers importable without triggering the solve.
if __name__ == "__main__":
    comps = scc(ns)
    weights, dag = condense(ns, comps)
    print(max_cover_by_paths(weights, dag))

Submission Info

Submission Time
Task R - グラフ (Graph)
User shotoyoo
Language PyPy3 (7.3.0)
Score 0
Code Size 2540 Byte
Status RE (Runtime Error)
Exec Time 120 ms
Memory 69924 KB

Judge Result

Set Name All
Score / Max Score 0 / 7
Status
RE × 22
Set Name Test Cases
All 00, 01, 02, 03, 04, 05, 06, 07, 08, 09, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 90, 91
Case Name Status Exec Time Memory
00 RE 120 ms 69492 KB
01 RE 92 ms 69672 KB
02 RE 99 ms 69632 KB
03 RE 96 ms 69868 KB
04 RE 97 ms 69588 KB
05 RE 98 ms 69732 KB
06 RE 98 ms 69732 KB
07 RE 96 ms 69544 KB
08 RE 95 ms 69780 KB
09 RE 94 ms 69824 KB
10 RE 93 ms 69504 KB
11 RE 94 ms 69772 KB
12 RE 95 ms 69924 KB
13 RE 101 ms 69536 KB
14 RE 99 ms 69860 KB
15 RE 97 ms 69712 KB
16 RE 94 ms 69596 KB
17 RE 95 ms 69484 KB
18 RE 98 ms 69876 KB
19 RE 101 ms 69580 KB
90 RE 80 ms 67868 KB
91 RE 78 ms 67836 KB