提出 #15780907


ソースコード 拡げる

import sys
import numpy as np
import numba
from numba import njit, b1, i4, i8, f8

read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline
readlines = sys.stdin.buffer.readlines

MOD = 200_003
R = 2  # primitive root

@njit((i8[:], ), cache=True)
def precompute(A):
    exp = np.zeros(MOD - 1, np.int64)
    log = np.zeros(MOD, np.int64)
    exp[0] = 1
    for k in range(1, MOD - 1):
        exp[k] = exp[k - 1] * R % MOD
        log[exp[k]] = k
    # r^k の個数を B[k] に入れる
    B = np.zeros(MOD - 1, np.int64)
    for x in A:
        if x != 0:
            B[log[x]] += 1
    return exp, log, B

@njit((i8[:], i8[:], i8[:], i8[:]), cache=True)
def get_ans(A, C, exp, log):
    # 順序対について集計
    ans = 0
    for i in range(len(C)):
        x = exp[i % (MOD - 1)]
        ans += x * C[i]
    for a in A:
        x = a * a % MOD
        ans -= x
    return ans // 2

A = np.array(read().split(), np.int64)[1:]

exp, log, B = precompute(A)
fft, ifft = np.fft.fft, np.fft.ifft
fft_len = 1 << 20
FB = fft(B, fft_len)
C = np.rint(ifft(FB * FB, fft_len)).astype(np.int64)

print(get_ans(A, C, exp, log))

提出情報

提出日時
問題 C - Product Modulo
ユーザ maspy
言語 Python (3.8.2)
得点 800
コード長 1182 Byte
結果 AC
実行時間 705 ms
メモリ 194988 KiB

ジャッジ結果

セット名 Sample All
得点 / 配点 0 / 0 800 / 800
結果
AC × 2
AC × 15
セット名 テストケース
Sample s1.txt, s2.txt
All 001.txt, 002.txt, 003.txt, 004.txt, 005.txt, 006.txt, 007.txt, 008.txt, 009.txt, 010.txt, 011.txt, 012.txt, 013.txt, s1.txt, s2.txt
ケース名 結果 実行時間 メモリ
001.txt AC 662 ms 193048 KiB
002.txt AC 672 ms 192444 KiB
003.txt AC 625 ms 192620 KiB
004.txt AC 693 ms 194072 KiB
005.txt AC 683 ms 194648 KiB
006.txt AC 689 ms 193728 KiB
007.txt AC 705 ms 194048 KiB
008.txt AC 681 ms 194004 KiB
009.txt AC 662 ms 193528 KiB
010.txt AC 701 ms 193952 KiB
011.txt AC 694 ms 193984 KiB
012.txt AC 688 ms 194988 KiB
013.txt AC 629 ms 192688 KiB
s1.txt AC 632 ms 192504 KiB
s2.txt AC 620 ms 191964 KiB