;; Submission #26982349
;; Source Code Expand
;; Contest prelude: sets up the CL-USER environment. The #+swank branches apply
;; when the SWANK feature is present; the #-swank branches apply otherwise.
(in-package :cl-user)
(eval-when (:compile-toplevel :load-toplevel :execute)
;; Optimization qualities: safety 2 under swank, safety 0 / debug 0 otherwise.
(defparameter *opt*
#+swank '(optimize (speed 3) (safety 2))
#-swank '(optimize (speed 3) (safety 0) (debug 0)))
#+swank (ql:quickload '(:cl-debug-print :fiveam :cp/util) :silent t)
#+swank (use-package :cp/util :cl-user)
;; Without swank, the #> reader macro reads the following form and wraps it in
;; VALUES, so #>form just evaluates form (primary value only).
#-swank (set-dispatch-macro-character
#\# #\> (lambda (s c p) (declare (ignore c p)) `(values ,(read s nil nil t))))
;; Register the :POPCNT and :SSE4 backend subfeatures with the SBCL compiler.
#+sbcl (dolist (f '(:popcnt :sse4)) (pushnew f sb-c:*backend-subfeatures*))
;; Install a freshly seeded global random state.
(setq *random-state* (make-random-state t)))
;; Without swank, make every non-style WARNING signaled during compilation
;; break into the debugger.
#-swank (eval-when (:compile-toplevel)
(setq *break-on-signals* '(and warning (not style-warning))))
;; With swank, #> is handled by cl-debug-print's reader instead.
#+swank (set-dispatch-macro-character #\# #\> #'cl-debug-print:debug-print-reader)
;; Defines the type abbreviations UINT<b> = (unsigned-byte b) and
;; INT<b> = (signed-byte b) for each listed bit width. The type names are
;; interned in the package that is current when the macro is expanded.
(macrolet ((define-int-types (&rest bits)
             `(progn
                ,@(loop for width in bits
                        collect `(deftype ,(intern (format nil "UINT~A" width)) ()
                                   '(unsigned-byte ,width))
                        collect `(deftype ,(intern (format nil "INT~A" width)) ()
                                   '(signed-byte ,width))))))
  (define-int-types 2 4 7 8 15 16 31 32 62 63 64))
;; Modulus (the prime 998244353) used by the modular operations that are
;; instantiated for CL-USER further down in this file.
(defconstant +mod+ 998244353)
;; Debug printer. Under swank it prints to *ERROR-OUTPUT*: a single form is
;; shown as "form => value", several forms as "(forms...) => (values...)".
;; Without swank the macro body is empty, so a call expands to NIL and FORMS are
;; not evaluated.
(defmacro dbg (&rest forms)
(declare (ignorable forms))
#+swank (if (= (length forms) 1)
`(format *error-output* "~A => ~A~%" ',(car forms) ,(car forms))
`(format *error-output* "~A => ~A~%" ',forms `(,,@forms))))
(declaim (inline println))
(defun println (obj &optional (stream *standard-output*))
  "Prints OBJ with PRINC followed by a newline on STREAM and returns OBJ.
A double-float is printed with *READ-DEFAULT-FLOAT-FORMAT* bound to
DOUBLE-FLOAT, so no exponent marker is emitted for it."
  (let ((*read-default-float-format*
          (typecase obj
            (double-float 'double-float)
            (t *read-default-float-format*))))
    (princ obj stream)
    (terpri stream)
    obj))
;; BEGIN_INSERTED_CONTENTS
(defpackage :cp/barrett
(:use :cl)
(:import-from #:sb-ext #:truly-the)
(:import-from #:sb-c
#:deftransform
#:derive-type
#:defoptimizer
#:defknown #:movable #:foldable #:flushable #:commutative
#:always-translatable
#:lvar-type
#:lvar-value
#:give-up-ir1-transform
#:integer-type-numeric-bounds
#:define-vop
#:tn-offset)
(:import-from #:sb-vm
#:move #:inst #:rax-offset #:rdx-offset #:temp-reg-tn
#:any-reg #:control-stack #:unsigned-reg
#:positive-fixnum
#:fixnumize #:ea)
(:import-from #:sb-kernel #:specifier-type)
(:import-from #:sb-int #:explicit-check #:constant-arg)
(:export #:fast-mod #:%himod #:%lomod)
(:documentation "Provides Barrett reduction."))
(in-package :cp/barrett)
;; Compile-time helpers shared by the type derivers and VOPs below.
(eval-when (:compile-toplevel :load-toplevel :execute)
;; Result type for *-HIGH62: when both argument types have an integer upper
;; bound, the result is bounded by their product shifted right by 62 bits;
;; otherwise it is only known to be a non-negative integer.
(defun derive-* (x y)
(let ((high1 (nth-value 1 (integer-type-numeric-bounds (lvar-type x))))
(high2 (nth-value 1 (integer-type-numeric-bounds (lvar-type y)))))
(specifier-type (if (and (integerp high1) (integerp high2))
`(integer 0 ,(ash (* high1 high2) -62))
`(integer 0)))))
;; Result type for a modular reduction: [0, HIGH) when the modulus type has an
;; integer upper bound HIGH (exclusive), otherwise any non-negative integer.
(defun derive-mod (modulus)
(let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
(specifier-type (if (integerp high)
`(integer 0 (,high))
`(integer 0)))))
;; Calls SB-VM::GPR-TN-P when that symbol exists at read time; otherwise the
;; check is skipped and T is returned.
(defun gpr-tn-p (x)
(declare (ignorable x))
#.(if (find-symbol "GPR-TN-P" :sb-vm)
`(funcall (intern "GPR-TN-P" :sb-vm) x)
t)))
;; *-high62
;; Known function, VOP and out-of-line stub for *-HIGH62, which takes two
;; (unsigned-byte 62) arguments. The deriver bounds its result by
;; (ash (* x y) -62), and the deftransform below uses it with a precomputed
;; floor(2^62 / modulus) to obtain a Barrett quotient.
(eval-when (:compile-toplevel :load-toplevel :execute)
(defknown *-high62 ((unsigned-byte 62) (unsigned-byte 62)) (unsigned-byte 62)
(movable foldable flushable commutative always-translatable)
:overwrite-fndb-silently t)
(defoptimizer (*-high62 derive-type) ((x y))
(derive-* x y))
(define-vop (fast-*-high62/fixnum)
(:translate *-high62)
(:policy :fast-safe)
(:args (x :scs (any-reg) :target rax)
(y :scs (any-reg control-stack)))
(:arg-types positive-fixnum positive-fixnum)
(:temporary (:sc any-reg :offset rax-offset
:from (:argument 0) :to :result)
rax)
(:temporary (:sc any-reg :offset rdx-offset :target r
:from :eval :to :result)
rdx)
(:results (r :scs (any-reg)))
(:result-types positive-fixnum)
(:note "inline *-high62")
(:vop-var vop)
(:save-p :compute-only)
(:generator 6
;; NOTE(review): presumably relies on fixnums carrying a single tag bit, so
;; that the RDX half of the widening MUL of two tagged operands holds
;; floor(x*y / 2^62) untagged and the shift by one re-tags it -- confirm
;; against the target's fixnum representation.
(move rax x)
(inst mul rax y)
(inst shl rdx 1)
(move r rdx)))
;; NOTE: I'm not using it for now because registered constant is slower
;; (define-vop (fast-c-*-high62-/fixnum)
;; (:translate *-high62)
;; (:policy :fast-safe)
;; (:args (x :scs (any-reg) :target rax))
;; (:info y)
;; (:arg-types positive-fixnum (:constant (unsigned-byte 62)))
;; (:temporary (:sc any-reg :offset rax-offset
;; :from (:argument 0) :to :result)
;; rax)
;; (:temporary (:sc any-reg :offset rdx-offset :target r
;; :from :eval :to :result)
;; rdx)
;; (:results (r :scs (any-reg)))
;; (:result-types positive-fixnum)
;; (:note "inline constant *-high62")
;; (:vop-var vop)
;; (:save-p :compute-only)
;; (:generator 5
;; (move rax x)
;; (inst mul rax (sb-c:register-inline-constant :qword (fixnumize y)))
;; (inst shl rdx 1)
;; (move r rdx)))
;; Out-of-line definition: the body call is translated by the VOP above.
(defun *-high62 (x y)
(declare (explicit-check))
(*-high62 x y)))
;; %himod
;; (%himod x m): known function taking an (unsigned-byte 32) X and a constant
;; (unsigned-byte 31) modulus M. It is used below as MOD+ on the sum of two
;; residues.
;; NOTE(review): presumably returns X - M when X >= M and X otherwise, i.e. one
;; conditional subtraction -- confirm against the emitted instructions.
(eval-when (:compile-toplevel :load-toplevel :execute)
(defknown %himod ((unsigned-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
(movable foldable flushable always-translatable)
:overwrite-fndb-silently t)
(defoptimizer (%himod derive-type) ((integer modulus))
(declare (ignore integer))
(derive-mod modulus))
(define-vop (fast-c-%himod)
(:translate %himod)
(:policy :fast-safe)
(:args (x :scs (any-reg) :target r))
(:info m)
(:arg-types positive-fixnum (:constant (unsigned-byte 31)))
(:temporary (:sc any-reg :from :eval :to :result
;; FIXME: hack to avoid collision of X and Y
:offset #.(tn-offset temp-reg-tn))
y)
(:results (r :scs (any-reg)))
(:result-types positive-fixnum)
(:note "inline constant %himod")
(:vop-var vop)
(:generator
4
;; maybe verbose
(assert (gpr-tn-p x))
(assert (not (sb-c:location= x y)))
;; M may arrive as an immediate TN; unwrap it to the integer value.
(when (sb-c:tn-p m)
(assert (sb-c:sc-is m sb-vm::immediate))
(setq m (sb-c::tn-value m)))
(setq m (fixnumize m))
;; R := X; compare R with (tagged M) - 2; Y := R - M; then move Y into R on
;; the unsigned "above" condition.
(move r x)
(inst cmp r (- m 2))
;; The effective-address constructor is chosen at read time: EA when it is
;; fbound, otherwise SB-VM's MAKE-EA.
(inst lea y #.(if (fboundp 'ea)
`(sb-vm::ea (- m) r)
`(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp (- m) :base r)))
(inst cmov :a r y))))
;; %lomod
;; (%lomod x m): known function taking a (signed-byte 32) X and a constant
;; (unsigned-byte 31) modulus M. It is used below as MOD- on the difference of
;; two residues.
;; NOTE(review): presumably returns X + M when X is negative and X otherwise,
;; i.e. one conditional addition -- confirm against the emitted instructions.
(eval-when (:compile-toplevel :load-toplevel :execute)
(defknown %lomod ((signed-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
(movable foldable flushable always-translatable)
:overwrite-fndb-silently t)
(defoptimizer (%lomod derive-type) ((integer modulus))
(declare (ignore integer))
(derive-mod modulus))
(define-vop (fast-c-%lomod)
(:translate %lomod)
(:policy :fast-safe)
(:args (x :scs (any-reg) :target r))
(:info m)
(:arg-types fixnum (:constant (unsigned-byte 31)))
(:temporary (:sc any-reg :from :eval :to :result :offset #.(tn-offset temp-reg-tn)) y)
(:results (r :scs (any-reg)))
(:result-types positive-fixnum)
(:note "inline constant %lomod")
(:vop-var vop)
(:generator
4
;; maybe verbose
(assert (gpr-tn-p x))
(assert (not (sb-c:location= x y)))
;; M may arrive as an immediate TN; unwrap it to the integer value.
(when (sb-c:tn-p m)
(assert (sb-c:sc-is m sb-vm::immediate))
(setq m (sb-c::tn-value m)))
(setq m (fixnumize m))
(move r x)
;; TODO: this instruction is rarely necessary because %LOMOD tends to be
;; called right after subtraction.
;; OR of R with itself leaves R unchanged; it is here only to set the flags
;; that the conditional move below tests.
(inst or r r)
;; Y := R + M, then move Y into R on the signed "less" condition.
(inst lea y #.(if (fboundp 'ea)
`(sb-vm::ea m r)
`(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp m :base r)))
(inst cmov :l r y))))
;; fast-mod
;; FAST-MOD behaves like CL:MOD. When the number is known to be an
;; (unsigned-byte 62) and the divisor is a constant (unsigned-byte 31) greater
;; than 1, the transform below replaces the call with a Barrett reduction.
(eval-when (:compile-toplevel :load-toplevel :execute)
(defknown fast-mod (integer unsigned-byte) unsigned-byte
(movable foldable flushable)
:overwrite-fndb-silently t)
(defoptimizer (fast-mod derive-type) ((integer modulus))
(declare (ignore integer))
(derive-mod modulus))
(defun fast-mod (number divisor)
"Is equivalent to CL:MOD for integer arguments"
(declare (explicit-check))
(mod number divisor))
(deftransform fast-mod ((number divisor)
((unsigned-byte 62) (constant-arg (unsigned-byte 31))) *
:important t)
"convert FAST-MOD to barrett reduction"
(let ((mod (lvar-value divisor)))
;; Divisors 0 and 1 are left to the ordinary definition.
(if (<= mod 1)
(give-up-ir1-transform)
;; M = floor(2^62 / mod) is computed at compile time. Q is the high part of
;; NUMBER * M, an estimate of floor(NUMBER / mod); X = NUMBER - Q * mod is
;; then corrected by at most one subtraction of mod.
(let ((m (floor (ash 1 62) mod)))
`(let* ((q (*-high62 number ,m))
(x (truly-the (signed-byte 32) (- number (* q ,mod)))))
(if (< x ,mod) x (- x ,mod))
;; Not so effective because branch prediction is usually correct?
;; (sb-ext:truly-the (mod ,mod) (%himod x ,mod))
))))))
(defpackage :cp/mod-operations
(:use :cl)
(:export #:define-mod-operations #:*modulus*)
(:documentation "Provides modular arithmetic."))
(in-package :cp/mod-operations)
;; Special variable exported by this package; initialized to 0, declared to
;; hold an (unsigned-byte 31), and on SBCL declared always-bound.
(eval-when (:compile-toplevel :load-toplevel :execute)
(defvar *modulus* 0))
(declaim ((unsigned-byte 31) *modulus*)
#+sbcl (sb-ext:always-bound *modulus*))
;; (define-mod-operations divisor &optional package)
;; Defines, with names interned in PACKAGE, the functions MOD*, MOD+ and MOD-
;; and the modify macros INCFMOD, DECFMOD and MULFMOD, all reducing modulo the
;; form DIVISOR. On SBCL each function also gets a source transform that expands
;; a call into nested (mod (op x y) divisor) forms.
(defmacro define-mod-operations
(divisor &optional (package #+sbcl (sb-int:sane-package) #-sbcl *package*))
(let ((mod* (intern "MOD*" package))
(mod+ (intern "MOD+" package))
(mod- (intern "MOD-" package))
(incfmod (intern "INCFMOD" package))
(decfmod (intern "DECFMOD" package))
(mulfmod (intern "MULFMOD" package)))
`(progn
;; Product of ARGS modulo DIVISOR; the empty product is (mod 1 divisor).
(defun ,mod* (&rest args)
(cond ((cdr args) (reduce (lambda (x y) (mod (* x y) ,divisor)) args))
(args (mod (car args) ,divisor))
(t (mod 1 ,divisor))))
;; Sum of ARGS modulo DIVISOR; the empty sum is 0.
(defun ,mod+ (&rest args)
(cond ((cdr args) (reduce (lambda (x y) (mod (+ x y) ,divisor)) args))
(args (mod (car args) ,divisor))
(t 0)))
;; With one argument: its negation modulo DIVISOR. With more: the first
;; argument minus the rest, reduced after every step.
(defun ,mod- (&rest args)
(if (cdr args)
(reduce (lambda (x y) (mod (- x y) ,divisor)) args)
(mod (- (car args)) ,divisor)))
#+sbcl
(eval-when (:compile-toplevel :load-toplevel :execute)
(locally (declare (sb-ext:muffle-conditions warning))
(sb-c:define-source-transform ,mod* (&rest args)
(case (length args)
(0 `(mod 1 ,',divisor))
(1 `(mod ,(car args) ,',divisor))
(otherwise (reduce (lambda (x y) `(mod (* ,x ,y) ,',divisor)) args))))
(sb-c:define-source-transform ,mod+ (&rest args)
(case (length args)
(0 0)
(1 `(mod ,(car args) ,',divisor))
(otherwise (reduce (lambda (x y) `(mod (+ ,x ,y) ,',divisor)) args))))
(sb-c:define-source-transform ,mod- (&rest args)
(case (length args)
;; Second value T declines the transform for a zero-argument call.
(0 (values nil t))
(1 `(mod (- ,(car args)) ,',divisor))
(otherwise (reduce (lambda (x y) `(mod (- ,x ,y) ,',divisor)) args))))))
(define-modify-macro ,incfmod (delta)
(lambda (x y) (mod (+ x y) ,divisor)))
(define-modify-macro ,decfmod (delta)
(lambda (x y) (mod (- x y) ,divisor)))
(define-modify-macro ,mulfmod (multiplier)
(lambda (x y) (mod (* x y) ,divisor))))))
;; Instantiate the operations in CL-USER with CL-USER::+MOD+ as the divisor.
(define-mod-operations cl-user::+mod+ :cl-user)
(defpackage :cp/read-fixnum
(:use :cl)
(:export #:read-fixnum))
(in-package :cp/read-fixnum)
(declaim (ftype (function * (values fixnum &optional)) read-fixnum))
(defun read-fixnum (&optional (in *standard-input*))
  "Reads the next decimal integer from IN and returns it as a fixnum.

Bytes are skipped until the first ASCII digit. The number is negative only when
a minus sign immediately precedes that digit; a `-' that is followed by some
other non-digit byte is ignored. Reading stops at (and consumes) the first
non-digit byte after the number. Signals an error when EOF or #\\Nul is met
before any digit.

NOTE: cannot read -2^62"
  (macrolet ((%read-byte ()
               `(the (unsigned-byte 8)
                     #+swank (char-code (read-char in nil #\Nul))
                     #-swank (sb-impl::ansi-stream-read-byte in nil #.(char-code #\Nul) nil))))
    (let* ((minus nil)
           (result (loop (let ((byte (%read-byte)))
                           (cond ((<= 48 byte 57)
                                  (return (- byte 48)))
                                 ((zerop byte) ; #\Nul
                                  (error "Read EOF or #\\Nul."))
                                 (t
                                  ;; Remember a sign only while it is the most
                                  ;; recent byte, so stray dashes earlier in
                                  ;; the input do not negate the number.
                                  (setq minus (= byte #.(char-code #\-)))))))))
      (declare ((integer 0 #.most-positive-fixnum) result))
      ;; Accumulate the remaining digits; the first non-digit ends the number.
      (loop
        (let* ((byte (%read-byte)))
          (if (<= 48 byte 57)
              (setq result (+ (- byte 48)
                              (* 10 (the (integer 0 #.(floor most-positive-fixnum 10))
                                         result))))
              (return (if minus (- result) result))))))))
(defpackage :cp/static-mod
(:use :cl)
(:export #:+mod+))
(in-package :cp/static-mod)
;; Modulus shared by the library packages: takes the value of CL-USER::+MOD+
;; when that symbol is bound at this point, and defaults to 998244353 otherwise.
(defconstant +mod+ (if (boundp 'cl-user::+mod+)
(symbol-value 'cl-user::+mod+)
998244353))
(defpackage :cp/binom-mod-prime
(:use :cl :cp/static-mod)
(:export #:binom #:perm #:multinomial #:stirling2 #:catalan #:multichoose
#:*fact* #:*fact-inv* #:*inv*)
(:documentation
"Provides tables of factorials, inverses, inverses of factorials etc.
modulo prime.
build: O(n)
query: O(1)
"))
(in-package :cp/binom-mod-prime)
;; TODO: non-global handling
;; Number of entries in each table below; valid indices are 0 .. +BINOM-SIZE+ - 1.
(defconstant +binom-size+ 110000)
(declaim ((simple-array (unsigned-byte 31) (*)) *fact* *fact-inv* *inv*))
;; The three tables are allocated here and filled in by INITIALIZE-BINOM.
(sb-ext:define-load-time-global *fact*
(make-array +binom-size+ :element-type '(unsigned-byte 31))
"table of factorials")
(sb-ext:define-load-time-global *fact-inv*
(make-array +binom-size+ :element-type '(unsigned-byte 31))
"table of inverses of factorials")
(sb-ext:define-load-time-global *inv*
(make-array +binom-size+ :element-type '(unsigned-byte 31))
"table of inverses of non-negative integers")
(defun initialize-binom ()
  "Fills *FACT*, *INV* and *FACT-INV* for every index below +BINOM-SIZE+.
Inverses use the recurrence inv[i] = -inv[mod % i] * floor(mod / i) (mod +MOD+).
*INV* at index 0 is left untouched."
  (declare (optimize (speed 3) (safety 0)))
  ;; Base cases: 0! = 1! = 1 and 1^-1 = 1.
  (dotimes (i 2)
    (setf (aref *fact* i) 1
          (aref *fact-inv* i) 1))
  (setf (aref *inv* 1) 1)
  (loop for i from 2 below +binom-size+
        do (let* ((prev (- i 1))
                  (fact-i (mod (* i (aref *fact* prev)) +mod+))
                  (inv-i (- +mod+
                            (mod (* (aref *inv* (rem +mod+ i))
                                    (floor +mod+ i))
                                 +mod+)))
                  (fact-inv-i (mod (* inv-i (aref *fact-inv* prev)) +mod+)))
             (setf (aref *fact* i) fact-i
                   (aref *inv* i) inv-i
                   (aref *fact-inv* i) fact-inv-i))))
;; Build the tables as soon as this file is loaded.
(initialize-binom)
(declaim (inline binom))
(defun binom (n k)
  "Returns nCk modulo +MOD+, the number of k-combinations of n things without
repetition. Returns 0 when N < K or when either argument is negative."
  (cond ((or (minusp n) (minusp k) (< n k)) 0)
        (t (let ((denominator-inverse
                   (mod (* (aref *fact-inv* k) (aref *fact-inv* (- n k))) +mod+)))
             (mod (* (aref *fact* n) denominator-inverse) +mod+)))))
(declaim (inline perm))
(defun perm (n k)
  "Returns nPk modulo +MOD+, the number of k-permutations of n things without
repetition. Returns 0 when N < K or when either argument is negative."
  (cond ((or (minusp n) (minusp k) (< n k)) 0)
        (t (let ((numerator (aref *fact* n))
                 (denominator-inverse (aref *fact-inv* (- n k))))
             (mod (* numerator denominator-inverse) +mod+)))))
(declaim (inline multichoose))
(defun multichoose (n k)
  "Returns the number of k-combinations of n things with repetition, i.e.
C(N+K-1, K) modulo +MOD+.

The case N = K = 0 is handled separately: there is exactly one empty
multiset, but BINOM would be called with a negative first argument and
return 0."
  (if (and (zerop n) (zerop k))
      1
      (binom (+ n k -1) k)))
(declaim (inline multinomial))
(defun multinomial (&rest ks)
  "Returns the multinomial coefficient K!/k_1!k_2!...k_n! modulo +MOD+ for
K = k_1 + k_2 + ... + k_n. K must be equal to or smaller than
MOST-POSITIVE-FIXNUM. (multinomial) returns 1."
  (let ((total 0)
        (inverse-product 1))
    (declare ((integer 0 #.most-positive-fixnum) inverse-product total))
    ;; Accumulate K and the product of the inverse factorials in one pass.
    (loop for k in ks
          do (setq total (+ total k)
                   inverse-product (mod (* inverse-product (aref *fact-inv* k))
                                        +mod+)))
    (mod (* inverse-product (aref *fact* total)) +mod+)))
(define-compiler-macro multinomial (&rest args)
  "Open-codes MULTINOMIAL for a fixed number of arguments.

Each argument form is bound to a fresh variable first, so it is evaluated
exactly once and in left-to-right order even though its value is needed both
for the inverse-factorial lookup and for the sum K. With zero or one argument
the result is the constant (mod 1 +mod+); a single argument form is still
evaluated for its side effects."
  (case (length args)
    ((0 1) `(progn ,@args ,(mod 1 +mod+)))
    (otherwise
     (let ((vars (loop repeat (length args) collect (gensym "K"))))
       `(let ,(mapcar #'list vars args)
          (mod (* ,(reduce (lambda (x y) `(mod (* ,x ,y) +mod+))
                           vars
                           :key (lambda (var) `(aref *fact-inv* ,var)))
                  (aref *fact* (+ ,@vars)))
               +mod+))))))
(declaim (inline stirling2))
(defun stirling2 (n k)
  "Returns the Stirling number of the second kind S2(N, K) modulo +MOD+,
computed as (1/K!) * sum_{i=0}^{K} (-1)^(K-i) * C(K, i) * i^N. Time complexity
is O(K log(N))."
  (declare ((integer 0 #.most-positive-fixnum) n k))
  (flet ((power-mod (base exp)
           ;; BASE^EXP modulo +MOD+ by binary exponentiation.
           (declare ((integer 0 #.most-positive-fixnum) base exp))
           (let ((acc 1))
             (declare ((integer 0 #.most-positive-fixnum) acc))
             (do ()
                 ((zerop exp) acc)
               (when (oddp exp)
                 (setq acc (mod (* acc base) +mod+)))
               (setq base (mod (* base base) +mod+)
                     exp (ash exp -1))))))
    (let ((signed-sum 0))
      (declare (fixnum signed-sum))
      (loop for i from 0 to k
            for term = (mod (* (binom k i) (power-mod i n)) +mod+)
            do (cond ((evenp (- k i))
                      (incf signed-sum term)
                      (when (>= signed-sum +mod+)
                        (decf signed-sum +mod+)))
                     (t
                      (decf signed-sum term)
                      (when (minusp signed-sum)
                        (incf signed-sum +mod+)))))
      (mod (* signed-sum (aref *fact-inv* k)) +mod+))))
(declaim (inline catalan))
(defun catalan (n)
  "Returns the N-th Catalan number modulo +MOD+, i.e. (2N)! / ((N+1)! N!)."
  (declare ((integer 0 #.most-positive-fixnum) n))
  (let* ((numerator (aref *fact* (* 2 n)))
         (denominator-inverse (mod (* (aref *fact-inv* (+ n 1))
                                      (aref *fact-inv* n))
                                   +mod+)))
    (mod (* numerator denominator-inverse) +mod+)))
(defpackage :cp/geometric-sequence
(:use :cl)
(:export #:make-geometric-sequence))
(in-package :cp/geometric-sequence)
(declaim (inline make-geometric-sequence))
(defun make-geometric-sequence (rate length modulus &key (scale 1) (element-type '(unsigned-byte 31)))
  "Returns a fresh vector of LENGTH elements of type ELEMENT-TYPE where
VECTOR[x] := SCALE * (RATE^x) mod MODULUS. Each element is derived from the
previous one by a single multiplication by RATE."
  (declare (fixnum rate scale)
           ((mod #.array-dimension-limit) length)
           ((integer 1 #.most-positive-fixnum) modulus))
  (let ((terms (make-array length :element-type element-type)))
    (when (plusp length)
      (setf (aref terms 0) (mod scale modulus))
      (do ((i 1 (1+ i)))
          ((>= i length))
        (setf (aref terms i)
              (mod (* rate (aref terms (1- i))) modulus))))
    terms))
(defpackage :cp/mod-power
(:use :cl)
(:export #:mod-power))
(in-package :cp/mod-power)
(declaim (inline mod-power))
(defun mod-power (base power modulus)
  "Returns BASE^POWER mod MODULUS by binary exponentiation. Note: 0^0 = 1.
BASE := integer
POWER, MODULUS := non-negative fixnum"
  (declare ((integer 0 #.most-positive-fixnum) modulus power)
           (integer base))
  (let ((square (mod base modulus))
        (acc (mod 1 modulus)))
    (declare ((integer 0 #.most-positive-fixnum) square acc))
    (do ()
        ((zerop power) acc)
      ;; Fold in the current square when the lowest exponent bit is set.
      (when (logbitp 0 power)
        (setq acc (mod (* acc square) modulus)))
      (setq square (mod (* square square) modulus)
            power (ash power -1)))))
(defpackage :cp/mod-inverse
(:use :cl)
#+sbcl (:import-from #:sb-c #:defoptimizer #:lvar-type #:integer-type-numeric-bounds
#:derive-type #:flushable #:foldable)
#+sbcl (:import-from :sb-kernel #:specifier-type)
(:export #:mod-inverse))
(in-package :cp/mod-inverse)
;; SBCL only: declare %MOD-INVERSE and MOD-INVERSE as known functions and give
;; both the same result-type deriver.
#+sbcl
(eval-when (:compile-toplevel :load-toplevel :execute)
(sb-c:defknown %mod-inverse ((integer 0) (integer 1)) (integer 0)
(flushable foldable)
:overwrite-fndb-silently t)
(sb-c:defknown mod-inverse (integer (integer 1)) (integer 0)
(flushable foldable)
:overwrite-fndb-silently t)
;; Result type: [0, HIGH) when the modulus type has an integer upper bound HIGH
;; (exclusive), otherwise any non-negative integer.
(defun derive-mod (modulus)
(let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
(specifier-type (if (integerp high)
`(integer 0 (,high))
`(integer 0)))))
(defoptimizer (%mod-inverse derive-type) ((integer modulus))
(declare (ignore integer))
(derive-mod modulus))
(defoptimizer (mod-inverse derive-type) ((integer modulus))
(declare (ignore integer))
(derive-mod modulus)))
(defun %mod-inverse (integer modulus)
  "Runs the extended Euclidean algorithm on INTEGER and MODULUS and returns the
resulting Bezout coefficient of INTEGER, shifted into the non-negative range by
adding MODULUS when it is negative. The result is only an inverse when the two
arguments are coprime; no such check is made here. The arithmetic is
specialized on the size of MODULUS."
  (declare (optimize (speed 3) (safety 0))
           #+sbcl (sb-ext:muffle-conditions sb-ext:compiler-note))
  (macrolet ((extended-euclid (word)
               `(let ((r0 integer)
                      (r1 modulus)
                      (s0 1)
                      (s1 0))
                  (declare (,word r0 r1 s0 s1))
                  (loop
                    (when (zerop r1)
                      (return))
                    (let ((quot (floor r0 r1)))
                      (decf r0 (the ,word (* quot r1)))
                      (rotatef r0 r1)
                      (decf s0 (the ,word (* quot s1)))
                      (rotatef s0 s1)))
                  (if (minusp s0)
                      (+ s0 modulus)
                      s0))))
    (typecase modulus
      ((unsigned-byte 31) (extended-euclid (signed-byte 32)))
      ((unsigned-byte 62) (extended-euclid (signed-byte 63)))
      (otherwise (extended-euclid integer)))))
(declaim (inline mod-inverse))
(defun mod-inverse (integer modulus)
  "Solves ax = 1 mod m and returns x in [0, MODULUS). Signals DIVISION-BY-ZERO
when INTEGER and MODULUS are not coprime. When MODULUS is 1 the verification is
skipped."
  (let* ((residue (mod integer modulus))
         (candidate (%mod-inverse residue modulus)))
    ;; Verify the candidate: it is only an inverse for coprime arguments.
    (when (and (/= 1 modulus)
               (/= 1 (mod (* residue candidate) modulus)))
      (error 'division-by-zero
             :operands (list residue modulus)
             :operation 'mod-inverse))
    candidate))
(defpackage :cp/ntt
(:use :cl :cp/mod-inverse :cp/barrett)
(:export #:define-ntt #:check-ntt-vector #:ntt-int #:ntt-vector)
(:documentation
"Provides fast number theoretic transform.
Reference:
https://github.com/ei1333/library/blob/master/math/fft/number-theoretic-transform-friendly-mod-int.cpp
https://github.com/atcoder/ac-library/tree/master/atcoder"))
(in-package :cp/ntt)
;; Element type and vector type used throughout the NTT code: residues are
;; stored as 31-bit unsigned integers in specialized simple vectors.
(deftype ntt-int () '(unsigned-byte 31))
(deftype ntt-vector () '(simple-array ntt-int (*)))
;; Helpers that must also exist at compile time because DEFINE-NTT calls some of
;; them while it is being expanded.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (declaim (inline %tzcount))
  (defun %tzcount (x)
    "Returns the number of the trailing zero bits. Note that (%TZCOUNT 0) = -1."
    (let ((lowest-set-bit (logand x (- x))))
      (1- (integer-length lowest-set-bit))))
  (defun %mod-power (base exp modulus)
    "Returns BASE^EXP mod MODULUS by binary exponentiation."
    (declare (ntt-int base exp modulus))
    (let ((acc 1))
      (declare (ntt-int acc))
      (do ()
          ((zerop exp) acc)
        (when (oddp exp)
          (setq acc (mod (* acc base) modulus)))
        (setq base (mod (* base base) modulus)
              exp (ash exp -1)))))
  (defun %mod-inverse (x modulus)
    "Returns X^(MODULUS-2) mod MODULUS, which is the inverse of X when MODULUS is
prime and X is not a multiple of it."
    (%mod-power x (- modulus 2) modulus))
  (defun %calc-generator (modulus)
    "MODULUS must be prime. Returns a generator of the multiplicative group
modulo MODULUS; a few common NTT moduli are answered from a fixed table."
    (declare (ntt-int modulus))
    (assert (>= modulus 2))
    (case modulus
      (2 1)
      ((167772161 469762049 998244353) 3)
      (754974721 11)
      (otherwise
       ;; Collect the distinct prime factors of MODULUS - 1; 2 is always one.
       (let ((prime-factors (make-array 20 :element-type 'ntt-int :initial-element 0))
             (n-factors 1)
             (cofactor (floor (- modulus 1) 2)))
         (declare ((integer 0 #.most-positive-fixnum) cofactor))
         (setf (aref prime-factors 0) 2)
         (loop while (evenp cofactor)
               do (setq cofactor (floor cofactor 2)))
         ;; Trial division by odd numbers.
         (loop for p of-type ntt-int from 3 by 2
               while (<= (* p p) cofactor)
               when (zerop (mod cofactor p))
                 do (setf (aref prime-factors n-factors) p)
                    (incf n-factors)
                    (loop while (zerop (mod cofactor p))
                          do (setq cofactor (floor cofactor p))))
         (when (> cofactor 1)
           (setf (aref prime-factors n-factors) cofactor)
           (incf n-factors))
         ;; The first candidate whose ((MODULUS-1)/q)-th power differs from 1
         ;; for every prime factor q is the answer.
         (loop for candidate of-type ntt-int from 2
               when (loop for idx from 0 below n-factors
                          never (= 1 (%mod-power candidate
                                                 (floor (- modulus 1)
                                                        (aref prime-factors idx))
                                                 modulus)))
                 do (return-from %calc-generator candidate)))))))
(declaim (ftype (function * (values ntt-vector &optional)) %adjust-array))
(defun %adjust-array (vector length)
  "Returns a fresh NTT-VECTOR of the given LENGTH holding the leading elements
of VECTOR, zero-padded when LENGTH is larger. This function always copies
VECTOR. (ANSI CL doesn't state whether CL:ADJUST-ARRAY should copy the given
array or not.)"
  (declare (optimize (speed 3))
           (vector vector)
           ((mod #.array-dimension-limit) length))
  (let ((source (coerce vector 'ntt-vector)))
    (cond ((= length (length source))
           (copy-seq source))
          (t
           (replace (make-array length :element-type 'ntt-int :initial-element 0)
                    source)))))
;; Asserts that the length of VECTOR is zero or a power of two and fits in
;; NTT-INT. The return value is not meaningful.
(defun check-ntt-vector (vector)
(declare (optimize (speed 3))
(vector vector))
(let ((len (length vector)))
(assert (and (zerop (logand len (- len 1))) ; power of two
(typep len 'ntt-int)))))
;; (define-ntt modulus &key ntt inverse-ntt convolve)
;; Defines three functions specialized to the constant form MODULUS:
;;   NTT          in-place forward transform (default name NTT!)
;;   INVERSE-NTT  in-place inverse transform (default name INVERSE-NTT!); the
;;                result is scaled by 1/length only when its optional second
;;                argument is true
;;   CONVOLVE     product of two coefficient vectors (default name CONVOLVE)
;; Default names are interned in the package current at expansion time. The
;; twiddle tables live in two load-time globals named by uninterned symbols.
(defmacro define-ntt (modulus &key ntt inverse-ntt convolve &environment env)
(assert (constantp modulus env))
(let* ((modulus #+sbcl (sb-int:constant-form-value modulus env) #-sbcl modulus)
(ntt (or ntt (intern "NTT!")))
(inverse-ntt (or inverse-ntt (intern "INVERSE-NTT!")))
(convolve (or convolve (intern "CONVOLVE")))
(ntt-base (gensym "*NTT-BASE*"))
(ntt-inv-base (gensym "*NTT-INV-BASE*"))
;; Number of trailing zero bits of MODULUS - 1; used as the table size.
(base-size (%tzcount (- modulus 1)))
(root (%calc-generator modulus))
;; NOTE(review): MODULUS was already resolved by the first binding; this
;; second, unguarded resolution looks redundant.
(modulus (sb-int:constant-form-value modulus env)))
(declare (ntt-int modulus))
`(progn
(declaim (ntt-vector ,ntt-base ,ntt-inv-base))
(sb-ext:define-load-time-global ,ntt-base
(make-array ,base-size :element-type 'ntt-int))
(sb-ext:define-load-time-global ,ntt-inv-base
(make-array ,base-size :element-type 'ntt-int))
;; base[i] = -(root^((modulus-1) >> (i+2))) mod modulus; inv-base[i] is its
;; modular inverse.
(dotimes (i ,base-size)
(setf (aref ,ntt-base i)
(mod (- (%mod-power ,root (ash (- ,modulus 1) (- (+ i 2))) ,modulus))
,modulus)
(aref ,ntt-inv-base i)
(%mod-inverse (aref ,ntt-base i) ,modulus)))
(declaim (ftype (function * (values ntt-vector &optional)) ,ntt))
(defun ,ntt (vector)
(declare (optimize (speed 3) (safety 0))
(vector vector))
(check-ntt-vector vector)
;; Local modular arithmetic built on the Barrett helpers.
(labels ((mod* (x y) (fast-mod (* x y) ,modulus))
(mod+ (x y) (%himod (+ x y) ,modulus))
(mod- (x y) (%lomod (- x y) ,modulus)))
(declare (inline mod* mod+ mod-))
(let* ((vector (coerce vector 'ntt-vector))
(len (length vector))
(base ,ntt-base))
(declare (ntt-vector vector base)
(ntt-int len))
(when (<= len 1)
(return-from ,ntt vector))
;; One pass per butterfly half-width M = len/2, len/4, ..., 1. W and K are
;; reset to 1 and 0 at the start of every pass.
(loop for m of-type ntt-int = (ash len -1) then (ash m -1)
while (> m 0)
for w of-type ntt-int = 1
for k of-type ntt-int = 0
do (loop for s of-type ntt-int from 0 below len by (* 2 m)
do (loop for i from s below (+ s m)
for j from (+ s m)
for x = (aref vector i)
for y = (mod* (aref vector j) w)
do (setf (aref vector i) (mod+ x y)
(aref vector j) (mod- x y)))
;; Advance the twiddle factor for the next block.
(incf k)
(setq w (mod* w (aref base (%tzcount k))))))
vector)))
(declaim (ftype (function * (values ntt-vector &optional)) ,inverse-ntt))
(defun ,inverse-ntt (vector &optional inverse)
(declare (optimize (speed 3) (safety 0))
(vector vector))
(check-ntt-vector vector)
(labels ((mod* (x y) (fast-mod (* x y) ,modulus))
(mod+ (x y) (%himod (+ x y) ,modulus))
(mod- (x y) (%lomod (- x y) ,modulus)))
(declare (inline mod* mod+ mod-))
(let* ((vector (coerce vector 'ntt-vector))
(len (length vector))
(base ,ntt-inv-base))
(declare (ntt-vector vector base)
(ntt-int len))
(when (<= len 1)
(return-from ,inverse-ntt vector))
;; Passes run with half-width M = 1, 2, ..., len/2, the reverse order of the
;; forward transform, using the inverse twiddle table.
(loop for m of-type ntt-int = 1 then (ash m 1)
while (< m len)
for w of-type ntt-int = 1
for k of-type ntt-int = 0
do (loop for s of-type ntt-int from 0 below len by (* 2 m)
do (loop for i from s below (+ s m)
for j from (+ s m)
for x = (aref vector i)
for y = (aref vector j)
do (setf (aref vector i) (mod+ x y)
(aref vector j) (mod* (mod- x y) w)))
(incf k)
(setq w (mod* w (aref base (%tzcount k))))))
;; Optional normalization: multiply every element by 1/len.
(when inverse
(let ((inv-len (mod-inverse len ,modulus)))
(dotimes (i len)
(setf (aref vector i) (mod* inv-len (aref vector i))))))
vector)))
(declaim (ftype (function * (values ntt-vector &optional)) ,convolve))
(defun ,convolve (vector1 vector2)
(declare (optimize (speed 3))
(vector vector1 vector2))
;; TODO: if (EQ VECTOR1 VECTOR2) holds, the number of FFTs can be
;; reduced.
(let* ((len1 (length vector1))
(len2 (length vector2))
(mul-len (max 0 (- (+ len1 len2) 1)))
(vector1 (coerce vector1 'ntt-vector))
(vector2 (coerce vector2 'ntt-vector)))
(declare (ntt-vector vector1 vector2)
((mod #.array-dimension-limit) mul-len))
;; An empty operand yields an empty product.
(when (or (zerop len1) (zerop len2))
(return-from ,convolve (make-array 0 :element-type 'ntt-int)))
;; naive convolution
;; Used whenever the shorter operand has at most 64 elements.
(when (<= (min len1 len2) 64)
(let ((res (make-array mul-len :element-type 'ntt-int :initial-element 0)))
(declare (optimize (speed 3) (safety 0)))
(dotimes (d mul-len)
;; 0 <= i <= deg1, 0 <= j <= deg2
(loop with coef of-type ntt-int = 0
for i from (max 0 (- d (- len2 1))) to (min d (- len1 1))
for j = (- d i)
do (setq coef (fast-mod (+ coef (* (aref vector1 i) (aref vector2 j)))
,modulus))
finally (setf (aref res d) coef)))
(return-from ,convolve res)))
;; Otherwise: pad copies of both operands to a power of two, transform,
;; multiply pointwise, transform back with scaling, and cut to MUL-LEN.
(let* (;; power of two ceiling
(required-len (ash 1 (integer-length (max 0 (- mul-len 1)))))
(vector1 (,ntt (%adjust-array vector1 required-len)))
(vector2 (,ntt (%adjust-array vector2 required-len))))
(dotimes (i required-len)
(setf (aref vector1 i)
(fast-mod (* (aref vector1 i) (aref vector2 i)) ,modulus)))
(subseq (,inverse-ntt vector1 t) 0 mul-len)))))))
;; Disabled by the #+(or) reader conditional: the next form is never read.
#+(or)
(define-ntt +mod+)
(defpackage :cp/polynomial-ntt
(:use :cl :cp/ntt :cp/mod-inverse :cp/mod-power :cp/static-mod)
(:export #:poly-multiply #:poly-inverse #:poly-floor #:poly-mod #:poly-sub #:poly-add
#:multipoint-eval #:poly-total-prod #:chirp-z #:bostan-mori
#:poly-differentiate! #:poly-integrate #:poly-log #:poly-exp #:poly-power))
(in-package :cp/polynomial-ntt)
;; TODO: integrate with cp/polynomial
;; Instantiate the NTT for +MOD+ in this package. The convolution is named
;; POLY-MULTIPLY; the transforms keep their default names NTT! and INVERSE-NTT!.
(define-ntt +mod+
:convolve poly-multiply)
(declaim (inline %adjust))
(defun %adjust (vector size)
  "Returns VECTOR itself when SIZE is NIL or equals its length; otherwise
returns a fresh NTT-VECTOR of SIZE elements holding the leading elements of
VECTOR, zero-padded when SIZE is larger."
  (declare (ntt-vector vector))
  (cond ((null size) vector)
        ((= size (length vector)) vector)
        (t (replace (make-array size :element-type 'ntt-int :initial-element 0)
                    vector))))
(declaim (inline %power-of-two-ceiling))
(defun %power-of-two-ceiling (x)
  "Returns 2 raised to the bit length of X - 1, i.e. the smallest power of two
that is not smaller than X for X >= 1. Returns 1 for X = 0."
  (declare (ntt-int x))
  (let ((exponent (integer-length (1- x))))
    (ash 1 exponent)))
;; (declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
;; (defun poly-inverse (poly &optional result-length)
;; (declare (optimize (speed 3))
;; (vector poly)
;; ((or null fixnum) result-length))
;; (let* ((poly (coerce poly 'ntt-vector))
;; (n (length poly)))
;; (declare (ntt-vector poly))
;; (when (or (zerop n)
;; (zerop (aref poly 0)))
;; (error 'division-by-zero
;; :operation #'poly-inverse
;; :operands (list poly)))
;; (let ((res (make-array 1
;; :element-type 'ntt-int
;; :initial-element (mod-inverse (aref poly 0) +mod+)))
;; (result-length (or result-length n)))
;; (declare (ntt-vector res))
;; (loop for i of-type ntt-int = 1 then (ash i 1)
;; while (< i result-length)
;; for decr = (poly-multiply (poly-multiply res res)
;; (subseq poly 0 (min (length poly) (* 2 i))))
;; for decr-len = (length decr)
;; do (setq res (%adjust res (* 2 i) :initial-element 0))
;; (dotimes (j (* 2 i))
;; (setf (aref res j)
;; (mod (the ntt-int
;; (+ (mod (* 2 (aref res j)) +mod+)
;; (if (>= j decr-len) 0 (- +mod+ (aref decr j)))))
;; +mod+))))
;; (%adjust res result-length))))
;; Reference: https://opt-cp.com/fps-fast-algorithms/
(declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
;; Returns the first RESULT-LENGTH coefficients (default: the length of POLY) of
;; the power-series inverse of POLY modulo +MOD+. Signals DIVISION-BY-ZERO when
;; POLY is empty or its constant term is zero. RES starts as the inverse of the
;; constant term and its length doubles on every round.
(defun poly-inverse (poly &optional result-length)
(declare (optimize (speed 3))
(vector poly)
((or null fixnum) result-length))
(let* ((poly (coerce poly 'ntt-vector))
(n (length poly)))
(declare (ntt-vector poly))
(when (or (zerop n)
(zerop (aref poly 0)))
(error 'division-by-zero
:operation #'poly-inverse
:operands (list poly)))
(let* ((result-length (or result-length n))
(res (make-array 1
:element-type 'ntt-int
:initial-element (mod-inverse (aref poly 0) +mod+))))
(declare (ntt-vector res))
;; Invariant: RES holds the first I coefficients of the inverse.
(loop for i of-type ntt-int = 1 then (ash i 1)
while (< i result-length)
for f of-type ntt-vector = (make-array (* 2 i) :element-type 'ntt-int
:initial-element 0)
for g of-type ntt-vector = (make-array (* 2 i) :element-type 'ntt-int
:initial-element 0)
;; F := first 2i coefficients of POLY, G := RES, both zero-padded to 2i.
do (replace f poly :end2 (min n (* 2 i)))
(replace g res)
(ntt! f)
(ntt! g)
;; F := F * G (cyclic, length 2i). INVERSE-NTT! is called without its
;; scaling argument; the missing factor is applied once below.
(dotimes (j (* 2 i))
(setf (aref f j) (mod (* (aref g j) (aref f j)) +mod+)))
(inverse-ntt! f)
;; Keep only the upper half of the product, moved down to indices 0..i-1.
(replace f f :start1 0 :end1 i :start2 i :end2 (* 2 i))
(fill f 0 :start i :end (* 2 i))
(ntt! f)
(dotimes (j (* 2 i))
(setf (aref f j) (mod (* (aref g j) (aref f j)) +mod+)))
(inverse-ntt! f)
;; Two unscaled inverse transforms were applied, so scale by (1/(2i))^2;
;; the sign is negated in the same step.
(let ((inv-len (mod-inverse (* 2 i) +mod+)))
(setq inv-len (mod (* inv-len (- +mod+ inv-len))
+mod+))
(dotimes (j i)
(setf (aref f j) (mod (* inv-len (aref f j)) +mod+)))
;; The scaled low half of F becomes coefficients i..2i-1 of RES.
(setq res (%adjust res (* 2 i)))
(replace res f :start1 i)))
(%adjust res result-length))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-floor))
(defun poly-floor (poly1 poly2)
  "Returns the quotient of the polynomial division POLY1 / POLY2 modulo +MOD+.
Trailing zero coefficients of both operands are ignored. Returns an empty
vector when the divisor has more significant coefficients than the dividend.
The quotient is obtained from the reversed operands via POLY-INVERSE."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((dividend (coerce poly1 'ntt-vector))
         (divisor (coerce poly2 'ntt-vector))
         (last1 (position 0 dividend :from-end t :test-not #'eql))
         (last2 (position 0 divisor :from-end t :test-not #'eql))
         (len1 (if last1 (+ 1 last1) 0))
         (len2 (if last2 (+ 1 last2) 0)))
    (if (> len2 len1)
        (make-array 0 :element-type 'ntt-int)
        (let* ((quotient-length (+ 1 (- len1 len2)))
               (reversed-dividend (nreverse (subseq dividend 0 len1)))
               (reversed-divisor (nreverse (subseq divisor 0 len2)))
               (product (poly-multiply reversed-dividend
                                       (poly-inverse reversed-divisor
                                                     quotient-length))))
          (nreverse (%adjust product quotient-length))))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-sub))
(defun poly-sub (poly1 poly2)
  "Returns POLY1 - POLY2 coefficient-wise modulo +MOD+, with trailing zero
coefficients removed. A negative difference is corrected by adding +MOD+ once."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((minuend (coerce poly1 'ntt-vector))
         (subtrahend (coerce poly2 'ntt-vector))
         (difference (make-array (max (length minuend) (length subtrahend))
                                 :element-type 'ntt-int
                                 :initial-element 0)))
    (replace difference minuend)
    (loop for i from 0
          for coefficient across subtrahend
          for delta = (- (aref difference i) coefficient)
          do (setf (aref difference i)
                   (if (minusp delta)
                       (+ delta +mod+)
                       delta)))
    (let ((last-nonzero (position 0 difference :from-end t :test-not #'eql)))
      (%adjust difference (if last-nonzero (+ 1 last-nonzero) 0)))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-add))
(defun poly-add (poly1 poly2)
  "Returns POLY1 + POLY2 coefficient-wise modulo +MOD+, with trailing zero
coefficients removed."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((augend (coerce poly1 'ntt-vector))
         (addend (coerce poly2 'ntt-vector))
         (sum (make-array (max (length augend) (length addend))
                          :element-type 'ntt-int
                          :initial-element 0)))
    (replace sum augend)
    (loop for i from 0
          for coefficient across addend
          do (setf (aref sum i)
                   (mod (+ (aref sum i) coefficient) +mod+)))
    (let ((last-nonzero (position 0 sum :from-end t :test-not #'eql)))
      (%adjust sum (if last-nonzero (+ 1 last-nonzero) 0)))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-mod))
(defun poly-mod (poly1 poly2)
  "Returns the remainder of the polynomial division POLY1 / POLY2 modulo +MOD+,
with trailing zero coefficients removed. An all-zero (or empty) POLY1 yields an
empty vector without dividing."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let ((dividend (coerce poly1 'ntt-vector))
        (divisor (coerce poly2 'ntt-vector)))
    (if (every #'zerop dividend)
        (make-array 0 :element-type 'ntt-int)
        (let* ((quotient (poly-floor dividend divisor))
               (remainder (poly-sub dividend (poly-multiply quotient divisor)))
               (last-nonzero (position 0 remainder :from-end t :test-not #'eql)))
          (subseq remainder 0 (if last-nonzero (+ 1 last-nonzero) 0))))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-total-prod))
(defun poly-total-prod (polys)
  "Returns the total polynomial product: polys[0] * polys[1] * ... * polys[n-1].
The product of no polynomials is the constant polynomial 1. Neighbouring
partial products are merged pairwise with doubling stride."
  (declare (optimize (speed 3))
           (vector polys))
  (let* ((n (length polys))
         (dp (make-array n :element-type t)))
    (declare ((mod #.array-dimension-limit) n))
    (when (zerop n)
      (return-from poly-total-prod (make-array 1 :element-type 'ntt-int :initial-element 1)))
    (replace dp polys)
    (do ((width 1 (ash width 1)))
        ((>= width n))
      (declare ((mod #.array-dimension-limit) width))
      ;; Merge the block starting at I with its right neighbour at I + WIDTH.
      (do ((i 0 (+ i (* width 2))))
          ((>= (+ i width) n))
        (declare ((mod #.array-dimension-limit) i))
        (setf (aref dp i)
              (poly-multiply (aref dp i) (aref dp (+ i width))))))
    (coerce (the vector (aref dp 0)) 'ntt-vector)))
(declaim (ftype (function * (values ntt-vector &optional)) multipoint-eval))
(defun multipoint-eval (poly points)
"Evaluates POLY at every element of POINTS and returns the values as an
NTT-VECTOR in the same order. The length of POINTS must be a power of two."
(declare (optimize (speed 3))
(vector poly points)
#+sbcl (sb-ext:muffle-conditions style-warning))
(check-ntt-vector points)
(let* ((poly (coerce poly 'ntt-vector))
(points (coerce points 'ntt-vector))
(len (length points))
;; TABLE is a binary tree stored as an array: node POS has children
;; 2*POS+1 and 2*POS+2 and holds the product of (x - p) over the points
;; p in its range [l, r).
(table (make-array (max 0 (- (* 2 len) 1)) :element-type 'ntt-vector))
(res (make-array len :element-type 'ntt-int)))
(unless (zerop len)
;; Build the product tree bottom-up.
(labels ((%build (l r pos)
(declare ((mod #.array-dimension-limit) l r pos))
(if (= (- r l) 1)
(let ((lin (make-array 2 :element-type 'ntt-int)))
(setf (aref lin 0) (- +mod+ (aref points l)) ;; NOTE: non-zero
(aref lin 1) 1)
(setf (aref table pos) lin))
(let ((mid (ash (+ l r) -1)))
(%build l mid (+ 1 (* pos 2)))
(%build mid r (+ 2 (* pos 2)))
(setf (aref table pos)
(poly-multiply (aref table (+ 1 (* pos 2)))
(aref table (+ 2 (* pos 2)))))))))
(%build 0 len 0))
;; Walk down the tree, reducing POLY modulo each node's product; at a leaf
;; the remainder is a constant (or empty, meaning 0): the value at that point.
(labels ((%eval (poly l r pos)
(declare ((mod #.array-dimension-limit) l r pos))
(if (= (- r l) 1)
(let ((tmp (poly-mod poly (aref table pos))))
(setf (aref res l) (if (zerop (length tmp)) 0 (aref tmp 0))))
(let ((mid (ash (+ l r) -1)))
(%eval (poly-mod poly (aref table (+ (* 2 pos) 1)))
l mid (+ (* 2 pos) 1))
(%eval (poly-mod poly (aref table (+ (* 2 pos) 2)))
mid r (+ (* 2 pos) 2))))))
(%eval poly 0 len 0)))
res))
;; not tested
(declaim (ftype (function * (values ntt-vector &optional)) chirp-z))
(defun chirp-z (poly base length)
  "Does multipoint evaluation of POLY with powers of BASE and returns the
vector P(base^0), P(base^1), ..., P(base^(length-1)). BASE must be coprime with
the modulus because its inverse is taken with MOD-INVERSE. An empty POLY gives
LENGTH zeros.

With N the length of POLY, the work consists of one POLY-MULTIPLY of a vector
of length N by a vector of length N + max(N, LENGTH), plus one MOD-POWER call
per vector element.

Reference:
https://codeforces.com/blog/entry/83532"
  (declare (optimize (speed 3))
           (vector poly)
           (ntt-int base length))
  (if (zerop (length poly))
      (make-array length :element-type 'ntt-int :initial-element 0)
      (let* ((poly (coerce poly 'ntt-vector))
             (base-inv (mod-inverse base +mod+))
             (n-coefs (length poly))
             (span (max length n-coefs))
             (total (+ n-coefs span))
             (weighted (make-array n-coefs :element-type 'ntt-int :initial-element 0))
             (chirp (make-array total :element-type 'ntt-int :initial-element 0)))
        (declare (ntt-int n-coefs span total))
        ;; WEIGHTED holds the coefficients in reverse order, coefficient K
        ;; being scaled by base^(-K(K-1)/2).
        (dotimes (i n-coefs)
          (let ((k (- n-coefs 1 i)))
            (setf (aref weighted i)
                  (mod (* (aref poly k)
                          (mod-power base-inv (ash (* k (- k 1)) -1) +mod+))
                       +mod+))))
        ;; CHIRP[i] = base^(i(i-1)/2).
        (dotimes (i total)
          (setf (aref chirp i)
                (mod-power base (ash (* i (- i 1)) -1) +mod+)))
        ;; The wanted sums sit at offsets N-1 .. N-1+LENGTH-1 of the product;
        ;; entry i still has to be scaled by base^(-i(i-1)/2).
        (let* ((offset (- n-coefs 1))
               (result (subseq (poly-multiply weighted chirp)
                               offset
                               (+ offset length))))
          (dotimes (i length)
            (setf (aref result i)
                  (mod (* (aref result i)
                          (mod-power base-inv (ash (* i (- i 1)) -1) +mod+))
                       +mod+)))
          result))))
(declaim (ftype (function * (values ntt-int &optional)) bostan-mori))
(defun bostan-mori (index num denom)
  "Returns [x^index](num(x)/denom(x)).

Signals DIVISION-BY-ZERO when DENOM is empty or its constant term is zero.
Each round multiplies numerator and denominator by DENOM(-x), keeps only the
coefficients whose parity matches INDEX (numerator) or the even ones
(denominator), and halves INDEX.

Reference:
https://arxiv.org/abs/2008.08822
https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47 (Japanese)"
  (declare (optimize (speed 3))
           (unsigned-byte index)
           (vector num denom))
  (labels ((even-part (p)
             ;; Coefficients of P at even indices.
             (let ((half (make-array (ceiling (length p) 2) :element-type 'ntt-int)))
               (loop for i from 0 below (length half)
                     do (setf (aref half i) (aref p (* 2 i))))
               half))
           (odd-part (p)
             ;; Coefficients of P at odd indices.
             (let ((half (make-array (floor (length p) 2) :element-type 'ntt-int)))
               (loop for i from 0 below (length half)
                     do (setf (aref half i) (aref p (+ 1 (* 2 i)))))
               half))
           (flip-odd-signs (p)
             ;; Fresh copy of P with every odd-index coefficient negated
             ;; modulo +MOD+, i.e. P(-x).
             (let ((flipped (copy-seq p)))
               (do ((i 1 (+ i 2)))
                   ((>= i (length flipped)))
                 (unless (zerop (aref flipped i))
                   (setf (aref flipped i) (- +mod+ (aref flipped i)))))
               flipped)))
    (let ((num (coerce num 'ntt-vector))
          (denom (coerce denom 'ntt-vector)))
      (when (or (zerop (length denom))
                (zerop (aref denom 0)))
        (error 'division-by-zero
               :operands (list num denom)
               :operation 'bostan-mori))
      (loop while (> index 0)
            do (let* ((denom- (flip-odd-signs denom))
                      (u (poly-multiply num denom-)))
                 (setq num (if (evenp index)
                               (even-part u)
                               (odd-part u))
                       denom (even-part (poly-multiply denom denom-))
                       index (ash index -1))))
      ;; INDEX is now 0: the answer is num[0] / denom[0].
      (let ((head (if (zerop (length num))
                      0
                      (aref num 0))))
        (rem (* head (mod-inverse (aref denom 0) +mod+))
             +mod+)))))
(declaim (inline poly-differentiate!))
(defun poly-differentiate! (p)
  "Returns the derivative of P with trailing zero coefficients removed.
The coefficients are overwritten in place in (COERCE P 'NTT-VECTOR), so P
itself is modified when it already is an NTT-VECTOR. An empty P is returned
as is."
  (declare (vector p))
  (let* ((coefs (coerce p 'ntt-vector))
         (len (length coefs)))
    (if (zerop len)
        coefs
        (progn
          ;; Shift every coefficient down one slot, multiplied by its
          ;; former exponent.
          (dotimes (i (- len 1))
            (declare (ntt-int i))
            (let ((k (+ i 1)))
              (setf (aref coefs i)
                    (mod (* (aref coefs k) k) +mod+))))
          ;; The last slot is stale, so only the first LEN-1 entries are
          ;; searched for the highest non-zero coefficient.
          (let ((last-nonzero (position 0 coefs
                                        :from-end t
                                        :end (- len 1)
                                        :test-not #'eql)))
            (subseq coefs 0 (if last-nonzero (+ last-nonzero 1) 0)))))))
;; Table of modular inverses indexed by the number being inverted. It starts
;; as #(0 1) (slot 0 is a placeholder, the inverse of 1 is 1) and is grown on
;; demand by FILL-INV!, which replaces the vector with a longer one.
(declaim (ntt-vector *inv*))
(defparameter *inv* (make-array 2 :element-type 'ntt-int :initial-contents '(0 1)))
(defun fill-inv! (new-size)
  "Makes sure *INV* has at least NEW-SIZE entries, extending it when needed.
The new length is (%POWER-OF-TWO-CEILING (MAX old-length NEW-SIZE)). Every new
entry X is derived from the already known entry (REM +MOD+ X) through
inv[X] = -(floor(+MOD+ / X) * inv[+MOD+ mod X]) mod +MOD+. Returns NIL."
  (declare (optimize (speed 3))
           ((mod #.array-dimension-limit) new-size))
  (let* ((current (length *inv*))
         (target (%power-of-two-ceiling (max current new-size))))
    (when (< current target)
      ;; %ADJUST is presumably a resize that keeps the existing prefix;
      ;; the recurrence below relies on the old entries being present.
      (let ((table (%adjust *inv* target)))
        (declare (ntt-vector table))
        (do ((x current (+ x 1)))
            ((>= x target))
          (declare ((mod #.array-dimension-limit) x))
          (multiple-value-bind (quotient remainder) (floor +mod+ x)
            (setf (aref table x)
                  (- +mod+
                     (mod (* (aref table remainder) quotient)
                          +mod+)))))
        (setq *inv* table)))
    nil))
(declaim (inline poly-integrate))
(defun poly-integrate (p)
  "Returns an indefinite integral of P whose integration constant is zero.
The result is a fresh vector with one more coefficient than P; an empty P
gives an empty vector. Division by K is done with the inverse table *INV*,
which is extended through FILL-INV! first."
  (declare (vector p))
  (let* ((coefs (coerce p 'ntt-vector))
         (len (length coefs)))
    (cond ((zerop len)
           (make-array 0 :element-type 'ntt-int))
          (t
           (fill-inv! (+ len 1))
           (let ((integral (make-array (+ len 1)
                                       :element-type 'ntt-int
                                       :initial-element 0))
                 (inv *inv*))
             ;; Coefficient K-1 of P becomes coefficient K, divided by K.
             (loop for k from 1 to len
                   do (setf (aref integral k)
                            (mod (* (aref coefs (- k 1)) (aref inv k))
                                 +mod+)))
             integral)))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-log))
(defun poly-log (poly &optional result-length)
  "Returns the first RESULT-LENGTH coefficients of the logarithm of POLY,
computed as the integral of POLY' * (1/POLY). RESULT-LENGTH defaults to the
length of POLY. The constant term of POLY must be 1, which is checked with
ASSERT. POLY itself is not modified (the derivative is taken on a copy)."
  (declare (optimize (speed 3))
           (vector poly)
           ((or null (integer 1 (#.array-dimension-limit))) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (result-length (or result-length (length poly))))
    (assert (= 1 (aref poly 0)))
    (let* ((derivative (poly-differentiate! (copy-seq poly)))
           (reciprocal (poly-inverse poly result-length))
           ;; Integration adds one coefficient, so one fewer is kept here.
           (quotient (%adjust (poly-multiply derivative reciprocal)
                              (- result-length 1))))
      (%adjust (poly-integrate quotient) result-length))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-exp))
(defun poly-exp (poly &optional result-length)
  "Returns the first RESULT-LENGTH coefficients of exp(POLY). RESULT-LENGTH
defaults to the length of POLY and may exceed it, in which case the missing
coefficients of POLY count as zero. POLY must be empty or have a zero constant
term, which is checked with ASSERT.

The result is built by Newton iteration: starting from the constant 1, each
round doubles the precision via res <- res * (1 + POLY - log(res))."
  (declare (optimize (speed 3))
           (vector poly)
           ((or null (mod #.array-dimension-limit)) result-length))
  (assert (or (zerop (length poly)) (zerop (aref poly 0))))
  (let* ((poly (coerce poly 'ntt-vector))
         (poly-length (length poly))
         (result-length (or result-length poly-length))
         (res (make-array 1 :element-type 'ntt-int :initial-element 1)))
    (loop until (>= (length res) result-length)
          do (let* ((new-len (* 2 (length res)))
                    (delta (poly-log res new-len)))
               (declare ((mod #.array-dimension-limit) new-len))
               ;; delta <- POLY - log(res) over the WHOLE new precision.
               ;; Indices at or beyond the end of POLY must still be negated
               ;; (POLY contributes 0 there); stopping at (length POLY) left
               ;; them with the wrong sign whenever RESULT-LENGTH was larger
               ;; than the length of POLY.
               (dotimes (i new-len)
                 (let ((coef (if (< i poly-length) (aref poly i) 0)))
                   (setf (aref delta i)
                         (mod (- coef (aref delta i)) +mod+))))
               ;; delta <- 1 + POLY - log(res)
               (setf (aref delta 0) (mod (+ 1 (aref delta 0)) +mod+))
               (setq res (%adjust (poly-multiply res delta) new-len))))
    (%adjust res result-length)))
(declaim (ftype (function * (values ntt-vector &optional)) poly-power))
(defun poly-power (poly exp &optional result-length)
  "Returns the first RESULT-LENGTH coefficients of POLY raised to the power
EXP. RESULT-LENGTH defaults to the length of POLY.

With c*x^s the lowest non-zero term of POLY, the power is computed as
c^EXP * x^(s*EXP) * exp(EXP * log(POLY / (c*x^s))). An all-zero POLY, a zero
RESULT-LENGTH, or s*EXP > RESULT-LENGTH gives a vector of zeros."
  (declare (optimize (speed 3))
           (vector poly)
           ((integer 0 #.most-positive-fixnum) exp)
           ((or null (mod #.array-dimension-limit)) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (result-length (or result-length (length poly)))
         (lead-pos (position 0 poly :test-not #'eql)))
    (flet ((zeros ()
             (make-array result-length :element-type 'ntt-int :initial-element 0))
           (scale! (vec factor)
             ;; Multiplies every coefficient of VEC by FACTOR in place.
             (declare (ntt-vector vec))
             (dotimes (i (length vec))
               (setf (aref vec i) (mod (* (aref vec i) factor) +mod+)))))
      (declare (inline zeros scale!))
      (if (or (null lead-pos)
              (zerop result-length)
              (> (* lead-pos exp) result-length))
          (zeros)
          (let* ((lead (aref poly lead-pos))
                 (series (subseq poly lead-pos)))
            ;; Normalize so that the constant term is 1, as POLY-LOG requires.
            (scale! series (mod-inverse lead +mod+))
            (setq series (poly-log series result-length))
            ;; Coefficients live modulo +MOD+, so EXP is reduced here...
            (scale! series (mod exp +mod+))
            (setq series (poly-exp series result-length))
            ;; ...but the scalar power uses the unreduced EXP.
            (scale! series (mod-power lead exp +mod+))
            (cond ((zerop lead-pos) series)
                  (t
                   ;; Re-insert the factor x^(s*EXP) by shifting.
                   (let ((shifted (zeros)))
                     (replace shifted series
                              :start1 (min (length shifted) (* lead-pos exp)))
                     shifted))))))))
;; BEGIN_USE_PACKAGE
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/polynomial-ntt :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/geometric-sequence :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/binom-mod-prime :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/read-fixnum :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/mod-operations :cl-user))
(in-package :cl-user)
;;;
;;; Body
;;;
;; Precomputed table passed to CALC. (floor (+ +mod+ 1) 2) is the inverse of 2
;; modulo +MOD+ because +MOD+ is odd. Presumably entry X holds (1/2)^X and
;; 51000 is the number of entries -- confirm against MAKE-GEOMETRIC-SEQUENCE.
(declaim ((simple-array uint31 (*)) *power/2*))
(sb-ext:define-load-time-global *power/2*
(make-geometric-sequence (floor (+ +mod+ 1) 2) 51000 +mod+))
(declaim (inline calc))
(defun calc (x)
  "Returns the modular product *FACT*[2X] * *POWER/2*[X] * *FACT-INV*[X].
Assuming those tables hold factorials, powers of 1/2 and inverse factorials,
this is (2X)! / (2^X * X!), the number of ways to split 2X items into X
unordered pairs."
  (let ((fact-2x (aref *fact* (* 2 x)))
        (half^x (aref *power/2* x))
        (inv-fact-x (aref *fact-inv* x)))
    (mod* fact-2x half^x inv-fact-x)))
(defun main ()
  "Reads N followed by 2N heights from standard input and prints one number.
For every distinct height occurring C times a polynomial with coefficients
BINOM(C, 2J) * CALC(J) is built; the polynomials are multiplied together and
the product is folded into the alternating sum of CALC(N - K) * product[K]."
  (declare #.*opt*)
  (let* ((n (read))
         (tally (make-array 100000 :element-type 'uint31 :initial-element 0)))
    (declare (uint31 n))
    ;; Count the occurrences of each height (1-based in the input).
    (loop repeat (* 2 n)
          do (let ((slot (- (read-fixnum) 1)))
               (incf (aref tally slot))))
    (let* ((group-sizes (delete 0 tally))
           (n-groups (length group-sizes))
           (polys (make-array n-groups :element-type t)))
      ;; One polynomial per height that actually occurs.
      (dotimes (g n-groups)
        (let* ((size (aref group-sizes g))
               (degree (floor size 2))
               (coefs (make-array (+ degree 1)
                                  :element-type 'uint31
                                  :initial-element 0)))
          (dotimes (j (length coefs))
            (setf (aref coefs j)
                  (mod* (binom size (* 2 j)) (calc j))))
          (setf (aref polys g) coefs)))
      (let ((product (poly-total-prod polys))
            (answer 0))
        (declare (uint31 answer))
        ;; Alternating sum: even indices are added, odd ones subtracted.
        (dotimes (k (length product))
          (let ((term (mod* (calc (- n k)) (aref product k))))
            (if (oddp k)
                (decfmod answer term)
                (incfmod answer term))))
        (println answer)))))
;; Run the solution immediately unless the file is loaded under SWANK
;; (interactive development), where MAIN is invoked by hand or by the tests.
#-swank (main)
;;;
;;; Test
;;;
;; SWANK-only setup: remember where this file lives, make its directory the
;; default/current directory, and define the data-file path ("test.dat" merged
;; against this file's pathname) and the URL of the problem being solved.
#+swank
(progn
(defparameter *lisp-file-pathname* (uiop:current-lisp-file-pathname))
(setq *default-pathname-defaults* (uiop:pathname-directory-pathname *lisp-file-pathname*))
(uiop:chdir *default-pathname-defaults*)
(defparameter *dat-pathname* (uiop:merge-pathnames* "test.dat" *lisp-file-pathname*))
(defparameter *problem-url* "https://atcoder.jp/contests/abl/tasks/abl_f"))
;; SWANK-only stub for generating a test input: it (re)creates the file named
;; by *DAT-PATHNAME*, superseding any existing one, and currently writes an
;; empty string to it.
#+swank
(defun gen-dat ()
(uiop:with-output-file (out *dat-pathname* :if-exists :supersede)
(format out "")))
;; SWANK-only benchmark helper: times a call of RUN on *DAT-PATHNAME*. OUT is
;; passed through to RUN and defaults to an empty broadcast stream, which
;; discards everything written to it.
#+swank
(defun bench (&optional (out (make-broadcast-stream)))
(time (run *dat-pathname* out)))
;; On SBCL outside SWANK, abort compilation with an error if the compiler has
;; recorded any undefined-function/variable warnings by the time this form is
;; reached, listing them in the error message.
#+(and sbcl (not swank))
(eval-when (:compile-toplevel)
(when sb-c::*undefined-warnings*
(error "undefined warnings: ~{~A~^ ~}" sb-c::*undefined-warnings*)))
;; To run: (5am:run! :sample)
;; SWANK-only FiveAM test: each check feeds an input string to RUN and compares
;; what it returns against the expected output text, trailing newline included.
#+swank
(5am:test :sample
;; First sample: N = 2 with heights 1 1 2 3, expected output 2.
(5am:is
(equal "2
"
(run "2
1
1
2
3
" nil)))
;; Second sample: N = 5 with ten heights, expected output 516.
(5am:is
(equal "516
"
(run "5
30
10
20
40
20
10
10
30
50
60
" nil))))
Submission Info
Submission Time
2021-11-02 06:26:18+0900
Task
F - Heights and Pairs
User
sansaqua
Language
Common Lisp (SBCL 2.0.3)
Score
600
Code Size
53122 Byte
Status
AC
Exec Time
110 ms
Memory
38652 KiB
Compile Error
; file: /imojudge/sandbox/Main.lisp
; in: DEFUN POLY-TOTAL-PROD
; (REPLACE CP/POLYNOMIAL-NTT::DP CP/POLYNOMIAL-NTT::POLYS)
;
; note: unable to
; optimize
; due to type uncertainty:
; The second argument is a VECTOR, not a SIMPLE-VECTOR.
; file: /imojudge/sandbox/Main.lisp
; in: DEFUN CHIRP-Z
; (SUBSEQ
; (CP/POLYNOMIAL-NTT:POLY-MULTIPLY CP/POLYNOMIAL-NTT::CS
; CP/POLYNOMIAL-NTT::DS)
; (- CP/POLYNOMIAL-NTT::N 1) (+ (- CP/POLYNOMIAL-NTT::N 1) LENGTH))
; --> LET* OR LET IF
; ==>
; LENGTH
;
; note: deleting unreachable code
; file: /imojudge/sandbox/Main.lisp
; in: DEFUN BOSTAN-MORI
; (> CP/POLYNOMIAL-NTT::INDEX 0)
;
; note: forced to do FAST-IF->-ZERO (cost 8)
; unable to do inline fixnum comparison (cost 3) because:
; The first argument is a UNSIGNED-BYTE, not a FIXNUM.
; (ASH CP/POLYNOMIAL-NTT::INDEX -1)
;
; note: forced to do full call
; unable to do inline ASH (cost 2) because:
; The first argument is a (INTEGER 1), not a FIXNUM.
; The result is a (VALUES UNSIGNED-BYTE &OPTIONAL), not a (VALUES FIXNUM
; &REST T).
; unable to do inline ASH (cost 3) because:
; The first argument is a (INTEGER 1), not a (UNSIGNED-BYTE 64).
; The result is a (VALUES UNSIGNED-BYTE &OPTIONAL), not a (VALUES
; (UNSIGNED-BYTE
; 64)
; &REST T).
; etc.
; (DEFUN CP/POLYNOMIAL-NTT:BOSTAN-MORI
; (CP/POLYNOMIAL-NTT::INDEX CP/POLYNOMIAL-NTT::NUM
; CP/POLYNOMIAL-NTT::DENOM)
; "Returns [x^index](num(x)/denom(x)).
;
; Reference:
; https://arxiv.org/abs/2008.08822
; https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47 (Japanese)"
; (DECLARE (OPTIMIZE (SPEED 3))
; (UNSIGNED-BYTE CP/POLYNOMIAL-NTT::INDEX)
; (VECTOR CP/PO...
Judge Result
Set Name
Sample
All
Score / Max Score
0 / 0
600 / 600
Status
Set Name
Test Cases
Sample
example0.txt, example1.txt
All
000.txt, 001.txt, 002.txt, 003.txt, 004.txt, 005.txt, 006.txt, 007.txt, 008.txt, 009.txt, 010.txt, 011.txt, 012.txt, 013.txt, 014.txt, 015.txt, 016.txt, 017.txt, 018.txt, 019.txt, 020.txt, example0.txt, example1.txt
Case Name
Status
Exec Time
Memory
000.txt
AC
44 ms
36788 KiB
001.txt
AC
40 ms
28672 KiB
002.txt
AC
46 ms
29388 KiB
003.txt
AC
45 ms
29828 KiB
004.txt
AC
45 ms
30332 KiB
005.txt
AC
53 ms
30288 KiB
006.txt
AC
55 ms
30728 KiB
007.txt
AC
59 ms
30948 KiB
008.txt
AC
62 ms
31400 KiB
009.txt
AC
64 ms
31840 KiB
010.txt
AC
74 ms
32600 KiB
011.txt
AC
78 ms
33616 KiB
012.txt
AC
86 ms
34500 KiB
013.txt
AC
88 ms
35028 KiB
014.txt
AC
95 ms
35844 KiB
015.txt
AC
95 ms
35948 KiB
016.txt
AC
94 ms
36244 KiB
017.txt
AC
110 ms
37748 KiB
018.txt
AC
109 ms
38140 KiB
019.txt
AC
109 ms
38652 KiB
020.txt
AC
86 ms
37608 KiB
example0.txt
AC
22 ms
28536 KiB
example1.txt
AC
33 ms
28472 KiB