提出 #19430090


ソースコード 拡げる

Copy
import sys
def input():
    """Read one line from stdin with trailing whitespace stripped (shadows the builtin)."""
    return sys.stdin.readline().rstrip()


sys.setrecursionlimit(2 * 10**5 + 10)


def write(x):
    """Write ``x`` plus a newline to stdout (helper; unused in this submission)."""
    return sys.stdout.write(x + "\n")


def _print(*x):
    """Debug print to stderr (helper; unused in this submission)."""
    return print(*x, file=sys.stderr)


# Input: n, then n rows of n integers; a nonzero g[i][j] is treated as an edge i -> j.
n = int(input())
g = [[int(tok) for tok in input().split()] for _ in range(n)]
# ns[i]: successors of vertex i, in increasing column order.
ns = [[j for j in range(n) if row[j]] for row in g]
def scc(ns):
    """Strongly connected components via an iterative (non-recursive) Tarjan.

    Args:
        ns: adjacency lists; ``ns[u]`` holds the successors of vertex ``u``,
            with vertices numbered ``0 .. len(ns) - 1``.

    Returns:
        A list of components, each a list of vertex ids. Components appear in
        the order the algorithm closes them (for Tarjan this is a reverse
        topological order of the condensation). Inside a component, vertices
        are listed in the order they are popped off the Tarjan stack, so the
        component's root comes last.
    """
    size = len(ns)
    index = [-1] * size        # discovery number; -1 means "not visited yet"
    low = [size] * size        # low-link value, overwritten on discovery
    on_stack = [False] * size
    tarjan_stack = []
    # Successors recorded at the moment a vertex is discovered:
    #   tree_next[u] - unvisited at that time; they are visited while u is
    #                  still open, so their low-link values are used
    #   back_next[u] - already on the Tarjan stack; their discovery numbers are used
    tree_next = [[] for _ in range(size)]
    back_next = [[] for _ in range(size)]
    components = []
    clock = 0

    for start in range(size):
        if index[start] != -1:
            continue
        # A non-negative entry means "visit this vertex"; ~v (always negative)
        # means "everything pushed after v is done, finish v".
        pending = [start]
        while pending:
            item = pending.pop()
            if item < 0:
                node = ~item
                node_low = min(
                    [low[node]]
                    + [low[w] for w in tree_next[node]]
                    + [index[w] for w in back_next[node]]
                )
                low[node] = node_low
                if node_low == index[node]:
                    # node roots an SCC: pop the stack down to and including it.
                    members = []
                    popped = -1
                    while popped != node:
                        popped = tarjan_stack.pop()
                        on_stack[popped] = False
                        members.append(popped)
                    components.append(members)
            elif index[item] == -1:
                # First visit. The finish marker goes below the successors so it
                # is popped only after all of them have been handled.
                pending.append(~item)
                tarjan_stack.append(item)
                on_stack[item] = True
                index[item] = low[item] = clock
                clock += 1
                for w in ns[item]:
                    if index[w] == -1:
                        pending.append(w)
                        tree_next[item].append(w)
                    elif on_stack[w]:
                        back_next[item].append(w)
            # Otherwise: a duplicate push of an already-visited vertex; skip it.
    return components
# Collapse each strongly connected component into one node of a DAG.
l = scc(ns)
m = len(l)
vs = [len(comp) for comp in l]   # vs[c]: number of original vertices in component c
cs = [None] * n                  # cs[u]: index of the component containing vertex u
for comp_id, comp in enumerate(l):
    for u in comp:
        cs[u] = comp_id
nns = [set() for _ in range(m)]  # nns[c]: components with a direct edge from component c
for u, succs in enumerate(ns):
    for v in succs:
        if cs[u] != cs[v]:
            nns[cs[u]].add(cs[v])
# nns: DAG
import networkx as nx
# Maximize the total size of components covered by two paths in the DAG,
# modeled as a min-cost flow of 2 units whose arc costs are negated sizes.
# Node layout: component c -> "in" node c and "out" node c + m;
# s = 2m is the source and t = 2m + 1 the sink. Arcs added without a
# "capacity" attribute are treated by networkx as uncapacitated.
g = nx.MultiDiGraph()
for c in range(m):
    c_out = c + m
    g.add_node(c)
    g.add_node(c_out)
    # Two parallel unit arcs through the component. Only one unit can take the
    # -vs[c] arc, so a component's size is counted at most once; a second
    # path may still pass through it at cost 0.
    for gain in (vs[c], 0):
        g.add_edge(c, c_out, capacity=1, weight=-gain)
for c, succs in enumerate(nns):
    for nxt in succs:
        g.add_edge(c + m, nxt)
s, t = 2 * m, 2 * m + 1
g.add_node(s, demand=-2)   # negative demand: s supplies the 2 units (two paths)
g.add_node(t, demand=2)
for c in range(m):
    g.add_edge(s, c)
    g.add_edge(c + m, t)
val, d = nx.network_simplex(g)
print(-val)

提出情報

提出日時
問題 R - グラフ
ユーザ shotoyoo
言語 Python (3.8.2)
得点 7
コード長 2540 Byte
結果 AC
実行時間 816 ms
メモリ 72060 KB

ジャッジ結果

セット名 All
得点 / 配点 7 / 7
結果 AC × 22
セット名 テストケース
All 00, 01, 02, 03, 04, 05, 06, 07, 08, 09, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 90, 91
ケース名 結果 実行時間 メモリ
00 AC 816 ms 72060 KB
01 AC 375 ms 55436 KB
02 AC 377 ms 55536 KB
03 AC 381 ms 55376 KB
04 AC 375 ms 55396 KB
05 AC 371 ms 55296 KB
06 AC 361 ms 54192 KB
07 AC 353 ms 54012 KB
08 AC 350 ms 54076 KB
09 AC 352 ms 53968 KB
10 AC 364 ms 55472 KB
11 AC 380 ms 55480 KB
12 AC 407 ms 56268 KB
13 AC 415 ms 56732 KB
14 AC 422 ms 57184 KB
15 AC 437 ms 57560 KB
16 AC 454 ms 58032 KB
17 AC 456 ms 58640 KB
18 AC 465 ms 58884 KB
19 AC 474 ms 59248 KB
90 AC 328 ms 52700 KB
91 AC 328 ms 52964 KB