;;; Submission #26982349


;;; Source Code Expand

;; Everything up to the first DEFPACKAGE below is defined in CL-USER.
(in-package :cl-user)
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; *OPT* is an OPTIMIZE declaration specifier: (SAFETY 2) when SWANK is
  ;; present (interactive development), (SAFETY 0) (DEBUG 0) otherwise.
  (defparameter *opt*
    #+swank '(optimize (speed 3) (safety 2))
    #-swank '(optimize (speed 3) (safety 0) (debug 0)))
  ;; Development-only libraries, loaded only under SWANK.
  #+swank (ql:quickload '(:cl-debug-print :fiveam :cp/util) :silent t)
  #+swank (use-package :cp/util :cl-user)
  ;; Without SWANK, #>FORM reads as (VALUES FORM): FORM is evaluated and only
  ;; its primary value is returned, with no debug output.
  #-swank (set-dispatch-macro-character
           #\# #\> (lambda (s c p) (declare (ignore c p)) `(values ,(read s nil nil t))))
  ;; Allow SBCL's backend to emit POPCNT and SSE4 instructions.
  #+sbcl (dolist (f '(:popcnt :sse4)) (pushnew f sb-c:*backend-subfeatures*))
  ;; Replace the global random state with a randomly initialized one.
  (setq *random-state* (make-random-state t)))
;; Outside SWANK, from compile time onward, any WARNING that is not a
;; STYLE-WARNING invokes BREAK (via *BREAK-ON-SIGNALS*) instead of passing
;; silently.
#-swank (eval-when (:compile-toplevel)
          (setq *break-on-signals* '(and warning (not style-warning))))
;; Under SWANK, #> is handled by cl-debug-print's reader instead (presumably a
;; printing variant of the wrapper above -- see that library).
#+swank (set-dispatch-macro-character #\# #\> #'cl-debug-print:debug-print-reader)

;; Defines the type aliases UINTn = (UNSIGNED-BYTE n) and INTn = (SIGNED-BYTE n)
;; for every width n listed below. The names are interned in the current package.
(macrolet ((define-int-types (&rest widths)
             `(progn
                ,@(loop for w in widths
                        collect `(deftype ,(intern (format nil "UINT~A" w)) ()
                                   '(unsigned-byte ,w))
                        collect `(deftype ,(intern (format nil "INT~A" w)) ()
                                   '(signed-byte ,w))))))
  (define-int-types 2 4 7 8 15 16 31 32 62 63 64))

;; Prime modulus for this program: 998244353 = 119 * 2^23 + 1.
;; CP/STATIC-MOD below picks this value up as its own +MOD+.
(defconstant +mod+ 998244353)

;; (DBG FORM ...) is a debug-printing macro. Under SWANK it prints to
;; *ERROR-OUTPUT*. With one form the output is "FORM => value". With several,
;; it prints the list of forms followed by the list of their values. Without
;; SWANK the macro body is empty, so a call expands to NIL and the forms are
;; never evaluated.
(defmacro dbg (&rest forms)
  (declare (ignorable forms))
  #+swank (if (= (length forms) 1)
              `(format *error-output* "~A => ~A~%" ',(car forms) ,(car forms))
              `(format *error-output* "~A => ~A~%" ',forms `(,,@forms))))

(declaim (inline println))
(defun println (obj &optional (stream *standard-output*))
  "Writes OBJ to STREAM as PRINC does, then a newline, and returns OBJ.
While a DOUBLE-FLOAT is printed, *READ-DEFAULT-FLOAT-FORMAT* is bound to
DOUBLE-FLOAT, so no exponent marker is emitted for it."
  (let ((*read-default-float-format* (if (typep obj 'double-float)
                                         'double-float
                                         *read-default-float-format*)))
    (princ obj stream)
    (terpri stream)
    obj))

;; BEGIN_INSERTED_CONTENTS
;; CP/BARRETT provides SBCL-specific primitives for fast modular reduction:
;; x86-64 VOPs (they use RAX/RDX) plus IR1 transforms. Most of the imported
;; symbols are SBCL compiler internals.
(defpackage :cp/barrett
  (:use :cl)
  (:import-from #:sb-ext #:truly-the)
  (:import-from #:sb-c
                #:deftransform
                #:derive-type
                #:defoptimizer
                #:defknown #:movable #:foldable #:flushable #:commutative
                #:always-translatable
                #:lvar-type
                #:lvar-value
                #:give-up-ir1-transform
                #:integer-type-numeric-bounds
                #:define-vop
                #:tn-offset)
  (:import-from #:sb-vm
                #:move #:inst #:rax-offset #:rdx-offset #:temp-reg-tn
                #:any-reg #:control-stack #:unsigned-reg
                #:positive-fixnum
                #:fixnumize #:ea)
  (:import-from #:sb-kernel #:specifier-type)
  (:import-from #:sb-int #:explicit-check #:constant-arg)
  (:export #:fast-mod #:%himod #:%lomod)
  (:documentation "Provides Barrett reduction."))
(in-package :cp/barrett)

;; Helpers used at compile time by the DEFOPTIMIZERs and VOPs below.
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Result type for (*-HIGH62 X Y). It is (INTEGER 0 floor(hiX*hiY / 2^62)),
  ;; where hiX and hiY are the upper bounds of the argument lvars' types. It
  ;; falls back to (INTEGER 0) if either bound is not a known integer.
  (defun derive-* (x y)
    (let ((high1 (nth-value 1 (integer-type-numeric-bounds (lvar-type x))))
          (high2 (nth-value 1 (integer-type-numeric-bounds (lvar-type y)))))
      (specifier-type (if (and (integerp high1) (integerp high2))
                          `(integer 0 ,(ash (* high1 high2) -62))
                          `(integer 0)))))

  ;; Result type [0, hi) for modular operations, where hi is the upper bound
  ;; of MODULUS's type. It falls back to (INTEGER 0) if that bound is unknown.
  (defun derive-mod (modulus)
    (let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
      (specifier-type (if (integerp high)
                          `(integer 0 (,high))
                          `(integer 0)))))

  ;; Tests whether TN X is a general-purpose-register TN, delegating to
  ;; SB-VM::GPR-TN-P when this SBCL has it (decided at read time via #.).
  ;; Otherwise it always returns T.
  (defun gpr-tn-p (x)
    (declare (ignorable x))
    #.(if (find-symbol "GPR-TN-P" :sb-vm)
          `(funcall (intern "GPR-TN-P" :sb-vm) x)
          t)))

;; *-high62
;; (*-HIGH62 X Y) for X, Y in (UNSIGNED-BYTE 62) yields the high part of the
;; product, floor(X*Y / 2^62). The derive-type optimizer uses that same
;; formula on the types' upper bounds. The product is computed by a single
;; MUL in the VOP below.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown *-high62 ((unsigned-byte 62) (unsigned-byte 62)) (unsigned-byte 62)
      (movable foldable flushable commutative always-translatable)
    :overwrite-fndb-silently t)

  (defoptimizer (*-high62 derive-type) ((x y))
    (derive-* x y))

  ;; X and Y arrive as tagged fixnums (ANY-REG / POSITIVE-FIXNUM). The
  ;; following assumes one fixnum tag bit, the x86-64 default (TODO confirm
  ;; for this build). The unsigned 128-bit product of the tagged values is then
  ;; 4*X*Y, so RDX receives floor(4XY / 2^64) = floor(XY / 2^62). SHL by 1
  ;; re-tags that value as a fixnum.
  (define-vop (fast-*-high62/fixnum)
    (:translate *-high62)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target rax)
           (y :scs (any-reg control-stack)))
    (:arg-types positive-fixnum positive-fixnum)
    (:temporary (:sc any-reg :offset rax-offset
                 :from (:argument 0) :to :result)
                rax)
    (:temporary (:sc any-reg :offset rdx-offset :target r
                 :from :eval :to :result)
                rdx)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline *-high62")
    (:vop-var vop)
    (:save-p :compute-only)
    (:generator 6
                (move rax x)
                (inst mul rax y)
                (inst shl rdx 1)
                (move r rdx)))

  ;; NOTE: I'm not using it for now because registered constant is slower
  ;; (define-vop (fast-c-*-high62-/fixnum)
  ;;   (:translate *-high62)
  ;;   (:policy :fast-safe)
  ;;   (:args (x :scs (any-reg) :target rax))
  ;;   (:info y)
  ;;   (:arg-types positive-fixnum (:constant (unsigned-byte 62)))
  ;;   (:temporary (:sc any-reg :offset rax-offset
  ;;                :from (:argument 0) :to :result)
  ;;               rax)
  ;;   (:temporary (:sc any-reg :offset rdx-offset :target r
  ;;                :from :eval :to :result)
  ;;               rdx)
  ;;   (:results (r :scs (any-reg)))
  ;;   (:result-types positive-fixnum)
  ;;   (:note "inline constant *-high62")
  ;;   (:vop-var vop)
  ;;   (:save-p :compute-only)
  ;;   (:generator 5
  ;;               (move rax x)
  ;;               (inst mul rax (sb-c:register-inline-constant :qword (fixnumize y)))
  ;;               (inst shl rdx 1)
  ;;               (move r rdx)))

  ;; Callable (full-call) definition. Its body is presumably compiled through
  ;; the VOP above (ALWAYS-TRANSLATABLE) rather than recursing.
  (defun *-high62 (x y)
    (declare (explicit-check))
    (*-high62 x y)))

;; %himod
;; (%HIMOD X M) with a compile-time constant M subtracts M once when X >= M
;; and otherwise returns X unchanged. The result is X mod M only under the
;; unchecked precondition 0 <= X < 2M. The NTT's MOD+ calls it with the sum of
;; two residues. Only a VOP for a constant M is provided.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown %himod ((unsigned-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
      (movable foldable flushable always-translatable)
    :overwrite-fndb-silently t)

  (defoptimizer (%himod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))

  (define-vop (fast-c-%himod)
    (:translate %himod)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target r))
    (:info m)
    (:arg-types positive-fixnum (:constant (unsigned-byte 31)))
    (:temporary (:sc any-reg :from :eval :to :result
                     ;; FIXME: hack to avoid collision of X and Y
                     :offset #.(tn-offset temp-reg-tn))
                y)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant %himod")
    (:vop-var vop)
    (:generator
     4
     ;; maybe verbose
     (assert (gpr-tn-p x))
     (assert (not (sb-c:location= x y)))
     ;; M may arrive as an immediate TN; unwrap it to its integer value.
     (when (sb-c:tn-p m)
       (assert (sb-c:sc-is m sb-vm::immediate))
       (setq m (sb-c::tn-value m)))
     ;; From here on M is in tagged (fixnum) representation, like R.
     (setq m (fixnumize m))
     (move r x)
     ;; Unsigned compare of tagged X against tagged M minus 2. With one tag bit
     ;; (TODO confirm) this tests X > M - 1, i.e. X >= M.
     (inst cmp r (- m 2))
     ;; Y := R - M. LEA leaves the flags from CMP intact.
     (inst lea y #.(if (fboundp 'ea)
                       `(sb-vm::ea (- m) r)
                       `(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp (- m) :base r)))
     ;; When X >= M, the result becomes X - M.
     (inst cmov :a r y))))

;; %lomod
;; (%LOMOD X M) with a compile-time constant M adds M once when X < 0 and
;; otherwise returns X unchanged. The result is X mod M only under the
;; unchecked precondition -M <= X < M. The NTT's MOD- calls it with the
;; difference of two residues. Only a VOP for a constant M is provided.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown %lomod ((signed-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
      (movable foldable flushable always-translatable)
    :overwrite-fndb-silently t)

  (defoptimizer (%lomod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))

  (define-vop (fast-c-%lomod)
    (:translate %lomod)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target r))
    (:info m)
    (:arg-types fixnum (:constant (unsigned-byte 31)))
    ;; Fixed to TEMP-REG-TN so that it cannot share a register with X
    ;; (same workaround as in %HIMOD).
    (:temporary (:sc any-reg :from :eval :to :result :offset #.(tn-offset temp-reg-tn)) y)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant %lomod")
    (:vop-var vop)
    (:generator
     4
     ;; maybe verbose
     (assert (gpr-tn-p x))
     (assert (not (sb-c:location= x y)))
     ;; M may arrive as an immediate TN; unwrap it to its integer value.
     (when (sb-c:tn-p m)
       (assert (sb-c:sc-is m sb-vm::immediate))
       (setq m (sb-c::tn-value m)))
     ;; From here on M is in tagged (fixnum) representation, like R.
     (setq m (fixnumize m))
     (move r x)
     ;; TODO: this instruction is rarely necessary because %LOMOD tends to be
     ;; called right after subtraction.
     ;; OR R,R sets the sign flag from X and clears the overflow flag.
     (inst or r r)
     ;; Y := R + M. LEA leaves the flags intact.
     (inst lea y #.(if (fboundp 'ea)
                       `(sb-vm::ea m r)
                       `(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp m :base r)))
     ;; When X < 0, the result becomes X + M.
     (inst cmov :l r y))))

;; fast-mod
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; FAST-MOD is a drop-in replacement for CL:MOD on integers. Declaring it as
  ;; a known function lets SBCL derive its result type. For suitable argument
  ;; types, SBCL also rewrites it into a Barrett reduction (DEFTRANSFORM below).
  (defknown fast-mod (integer unsigned-byte) unsigned-byte
      (movable foldable flushable)
    :overwrite-fndb-silently t)

  ;; The result lies in [0, upper bound of DIVISOR's type).
  (defoptimizer (fast-mod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))

  ;; Out-of-line definition, used whenever the transform does not apply.
  (defun fast-mod (number divisor)
    "Is equivalent to CL:MOD for integer arguments"
    (declare (explicit-check))
    (mod number divisor))

  ;; Barrett reduction for NUMBER in (UNSIGNED-BYTE 62) and a constant DIVISOR
  ;; (called MOD below) in (UNSIGNED-BYTE 31) with MOD >= 2. Let
  ;; M = floor(2^62 / MOD) and Q = floor(NUMBER * M / 2^62), computed by
  ;; *-HIGH62. Then NUMBER/MOD - 2 < Q <= NUMBER/MOD. Hence
  ;; X = NUMBER - Q*MOD lies in [0, 2*MOD), and one conditional subtraction
  ;; yields NUMBER mod MOD.
  (deftransform fast-mod ((number divisor)
                          ((unsigned-byte 62) (constant-arg (unsigned-byte 31))) *
                          :important t)
    "convert FAST-MOD to barrett reduction"
    (let ((mod (lvar-value divisor)))
      (if (<= mod 1)
          (give-up-ir1-transform)
          (let ((m (floor (ash 1 62) mod)))
            `(let* ((q (*-high62 number ,m))
                    ;; Declare the exact range [0, 2*MOD). The former
                    ;; (SIGNED-BYTE 32) claim was false for MOD > 2^30, where
                    ;; 2*MOD - 1 exceeds 2^31 - 1.
                    (x (truly-the (integer 0 (,(* 2 mod)))
                                  (- number (* q ,mod)))))
               (if (< x ,mod) x (- x ,mod))
               ;; Not so effective because branch prediction is usually correct?
               ;; (sb-ext:truly-the (mod ,mod) (%himod x ,mod))
               ))))))

;; CP/MOD-OPERATIONS provides a macro that generates modular-arithmetic
;; helpers for a given divisor.
(defpackage :cp/mod-operations
  (:use :cl)
  (:export #:define-mod-operations #:*modulus*)
  (:documentation "Provides modular arithmetic."))
(in-package :cp/mod-operations)

;; *MODULUS* is given a value inside EVAL-WHEN so that it is already bound at
;; compile time, when the SB-EXT:ALWAYS-BOUND proclamation below is processed.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defvar *modulus* 0))
(declaim ((unsigned-byte 31) *modulus*)
         #+sbcl (sb-ext:always-bound *modulus*))

;; (DEFINE-MOD-OPERATIONS DIVISOR [PACKAGE]) defines the following names,
;; interned in PACKAGE. On SBCL the default is SB-INT:SANE-PACKAGE,
;; presumably the current package.
;;   MOD*, MOD+, MOD-           functions that fold *, + or - over their
;;                              arguments, reducing modulo DIVISOR after each
;;                              step. With a single argument they return
;;                              (mod x DIVISOR) or (mod (- x) DIVISOR).
;;                              With none, MOD* gives (mod 1 DIVISOR) and
;;                              MOD+ gives 0.
;;   INCFMOD, DECFMOD, MULFMOD  modify macros. For example, (INCFMOD place d)
;;                              stores (mod (+ place d) DIVISOR) into PLACE.
;; On SBCL, source transforms also open-code calls to MOD*, MOD+ and MOD-.
;; DIVISOR is spliced in as a form, so it is evaluated at every use site.
(defmacro define-mod-operations
    (divisor &optional (package #+sbcl (sb-int:sane-package) #-sbcl *package*))
  (let ((mod* (intern "MOD*" package))
        (mod+ (intern "MOD+" package))
        (mod- (intern "MOD-" package))
        (incfmod (intern "INCFMOD" package))
        (decfmod (intern "DECFMOD" package))
        (mulfmod (intern "MULFMOD" package)))
    `(progn
       (defun ,mod* (&rest args)
         (cond ((cdr args) (reduce (lambda (x y) (mod (* x y) ,divisor)) args))
               (args (mod (car args) ,divisor))
               (t (mod 1 ,divisor))))
       (defun ,mod+ (&rest args)
         (cond ((cdr args) (reduce (lambda (x y) (mod (+ x y) ,divisor)) args))
               (args (mod (car args) ,divisor))
               (t 0)))
       (defun ,mod- (&rest args)
         ;; NOTE(review): with no arguments this evaluates (- NIL) and signals
         ;; a type error. CL:- likewise requires at least one argument.
         (if (cdr args)
             (reduce (lambda (x y) (mod (- x y) ,divisor)) args)
             (mod (- (car args)) ,divisor)))

       #+sbcl
       (eval-when (:compile-toplevel :load-toplevel :execute)
         ;; Warnings are muffled here, presumably to silence noise about
         ;; (re)defining source transforms for these names.
         (locally (declare (sb-ext:muffle-conditions warning))
           (sb-c:define-source-transform ,mod* (&rest args)
             (case (length args)
               (0 `(mod 1 ,',divisor))
               (1 `(mod ,(car args) ,',divisor))
               (otherwise (reduce (lambda (x y) `(mod (* ,x ,y) ,',divisor)) args))))
           (sb-c:define-source-transform ,mod+ (&rest args)
             (case (length args)
               (0 0)
               (1 `(mod ,(car args) ,',divisor))
               (otherwise (reduce (lambda (x y) `(mod (+ ,x ,y) ,',divisor)) args))))
           (sb-c:define-source-transform ,mod- (&rest args)
             (case (length args)
               ;; Declines to transform, so the call stays a full call to MOD-.
               (0 (values nil t))
               (1 `(mod (- ,(car args)) ,',divisor))
               (otherwise (reduce (lambda (x y) `(mod (- ,x ,y) ,',divisor)) args))))))

       (define-modify-macro ,incfmod (delta)
         (lambda (x y) (mod (+ x y) ,divisor)))
       (define-modify-macro ,decfmod (delta)
         (lambda (x y) (mod (- x y) ,divisor)))
       (define-modify-macro ,mulfmod (multiplier)
         (lambda (x y) (mod (* x y) ,divisor))))))

;; Defines MOD*, MOD+, MOD-, INCFMOD, DECFMOD and MULFMOD in CL-USER, all
;; reducing modulo CL-USER::+MOD+.
(define-mod-operations cl-user::+mod+ :cl-user)

(defpackage :cp/read-fixnum
  (:use :cl)
  (:export #:read-fixnum))
(in-package :cp/read-fixnum)

(declaim (ftype (function * (values fixnum &optional)) read-fixnum))
(defun read-fixnum (&optional (in *standard-input*))
  "NOTE: cannot read -2^62"
  ;; Reads one decimal integer from IN, byte by byte, in two phases:
  ;; 1. Bytes before the first digit are skipped. If any of them is #\-, the
  ;;    result is negated. EOF or #\Nul in this phase signals an error.
  ;; 2. Digits are accumulated up to the first non-digit byte, which is
  ;;    consumed. EOF ends the number in this phase.
  ;; NOTE(review): RESULT's range is only declared, not checked, so inputs
  ;; above MOST-POSITIVE-FIXNUM are not detected here.
  (macrolet ((%read-byte ()
               ;; Next byte of IN as an integer; EOF is reported as 0 (#\Nul).
               ;; Outside SWANK this calls SBCL's internal
               ;; SB-IMPL::ANSI-STREAM-READ-BYTE.
               `(the (unsigned-byte 8)
                     #+swank (char-code (read-char in nil #\Nul))
                     #-swank (sb-impl::ansi-stream-read-byte in nil #.(char-code #\Nul) nil))))
    (let* ((minus nil)
           (result (loop (let ((byte (%read-byte)))
                           (cond ((<= 48 byte 57)
                                  (return (- byte 48)))
                                 ((zerop byte) ; #\Nul
                                  (error "Read EOF or #\Nul."))
                                 ((= byte #.(char-code #\-))
                                  (setq minus t)))))))
      (declare ((integer 0 #.most-positive-fixnum) result))
      (loop
        (let* ((byte (%read-byte)))
          (if (<= 48 byte 57)
              (setq result (+ (- byte 48)
                              (* 10 (the (integer 0 #.(floor most-positive-fixnum 10))
                                         result))))
              (return (if minus (- result) result))))))))

;; CP/STATIC-MOD exports the modulus constant +MOD+ used by the packages
;; below.
(defpackage :cp/static-mod
  (:use :cl)
  (:export #:+mod+))
(in-package :cp/static-mod)

;; Reuses CL-USER::+MOD+ if that symbol is bound when this form is evaluated.
;; Otherwise it falls back to 998244353.
(defconstant +mod+ (if (boundp 'cl-user::+mod+)
                       (symbol-value 'cl-user::+mod+)
                       998244353))

(defpackage :cp/binom-mod-prime
  (:use :cl :cp/static-mod)
  (:export #:binom #:perm #:multinomial #:stirling2 #:catalan #:multichoose
           #:*fact* #:*fact-inv* #:*inv*)
  (:documentation
   "Provides tables of factorials, inverses, inverses of factorials etc.
modulo prime.

build: O(n)
query: O(1)
"))
(in-package :cp/binom-mod-prime)

;; TODO: non-global handling
;; Number of entries in each table below. Valid indices are
;; 0 .. +BINOM-SIZE+ - 1.
(defconstant +binom-size+ 110000)

;; Global tables of residues modulo CP/STATIC-MOD:+MOD+, filled in by
;; INITIALIZE-BINOM.
(declaim ((simple-array (unsigned-byte 31) (*)) *fact* *fact-inv* *inv*))
(sb-ext:define-load-time-global *fact*
  (make-array +binom-size+ :element-type '(unsigned-byte 31))
  "table of factorials")
(sb-ext:define-load-time-global *fact-inv*
  (make-array +binom-size+ :element-type '(unsigned-byte 31))
  "table of inverses of factorials")
(sb-ext:define-load-time-global *inv*
  (make-array +binom-size+ :element-type '(unsigned-byte 31))
  "table of inverses of non-negative integers")

(defun initialize-binom ()
  "Fills *FACT*, *INV* and *FACT-INV* for every index below +BINOM-SIZE+.
The inverses use the linear-time recurrence
  inv(i) = +MOD+ - (inv(+MOD+ mod i) * floor(+MOD+ / i) mod +MOD+),
which requires +MOD+ to be a prime larger than every index."
  (declare (optimize (speed 3) (safety 0)))
  (let ((fact *fact*)
        (fact-inv *fact-inv*)
        (inv *inv*))
    (setf (aref fact 0) 1
          (aref fact 1) 1
          (aref fact-inv 0) 1
          (aref fact-inv 1) 1
          (aref inv 1) 1)
    (do ((i 2 (1+ i)))
        ((>= i +binom-size+))
      (multiple-value-bind (quotient remainder) (floor +mod+ i)
        (setf (aref fact i) (mod (* i (aref fact (1- i))) +mod+))
        (setf (aref inv i) (- +mod+ (mod (* (aref inv remainder) quotient) +mod+)))
        (setf (aref fact-inv i) (mod (* (aref inv i) (aref fact-inv (1- i))) +mod+))))))

(initialize-binom)

(declaim (inline binom))
(defun binom (n k)
  "Returns the binomial coefficient C(N, K) modulo +MOD+. This is the number
of ways to choose K of N items without repetition. Returns 0 unless
0 <= K <= N."
  (cond ((<= 0 k n)
         (let ((inverse-denominator (mod (* (aref *fact-inv* k)
                                            (aref *fact-inv* (- n k)))
                                         +mod+)))
           (mod (* (aref *fact* n) inverse-denominator) +mod+)))
        (t 0)))

(declaim (inline perm))
(defun perm (n k)
  "Returns P(N, K) = N!/(N-K)! modulo +MOD+. This is the number of ordered
selections of K of N items without repetition. Returns 0 unless
0 <= K <= N."
  (if (<= 0 k n)
      (mod (* (aref *fact* n) (aref *fact-inv* (- n k))) +mod+)
      0))

(declaim (inline multichoose))
(defun multichoose (n k)
  "Returns the number of k-combinations of n things with repetition, i.e.
C(n+k-1, k) mod +MOD+. There is exactly one way to choose zero items from
zero kinds, so (multichoose 0 0) = 1. The general formula would give
C(-1, 0), which BINOM reports as 0."
  (if (and (zerop n) (zerop k))
      1
      (binom (+ n k -1) k)))

(declaim (inline multinomial))
(defun multinomial (&rest ks)
  "Returns the multinomial coefficient K!/(k_1! k_2! ... k_n!) modulo +MOD+,
where K = k_1 + k_2 + ... + k_n. K must not exceed MOST-POSITIVE-FIXNUM and
must be a valid index into *FACT*. (multinomial) returns 1."
  (let ((total 0)
        (inverse-denominator 1))
    (declare ((integer 0 #.most-positive-fixnum) total inverse-denominator))
    (loop for k in ks
          do (incf total k)
             (setq inverse-denominator
                   (mod (* inverse-denominator (aref *fact-inv* k)) +mod+)))
    (mod (* inverse-denominator (aref *fact* total)) +mod+)))

(define-compiler-macro multinomial (&rest args)
  "Open-codes MULTINOMIAL when the number of arguments is known. Each
argument form is evaluated exactly once, left to right, as in an ordinary
function call."
  (case (length args)
    (0 (mod 1 +mod+))
    ;; k!/k! = 1, but the argument must still be evaluated for its effects.
    (1 `(progn ,(first args) ,(mod 1 +mod+)))
    (otherwise
     ;; Bind every argument once. The previous expansion spliced each form
     ;; into both an AREF and the final sum, evaluating it twice.
     (let ((vars (loop repeat (length args) collect (gensym "K"))))
       `(let ,(mapcar #'list vars args)
          (mod (* ,(reduce (lambda (x y) `(mod (* ,x ,y) +mod+))
                           vars
                           :key (lambda (v) `(aref *fact-inv* ,v)))
                  (aref *fact* (+ ,@vars)))
               +mod+))))))

(declaim (inline stirling2))
(defun stirling2 (n k)
  "Returns the Stirling number of the second kind S2(N, K) modulo +MOD+,
computed as
  S2(n, k) = (1/k!) * sum_{i=0}^{k} (-1)^(k-i) * C(k, i) * i^n.
Time complexity is O(k log n)."
  (declare ((integer 0 #.most-positive-fixnum) n k))
  (flet ((pow-mod (b e)
           ;; B^E mod +MOD+ by square-and-multiply, with 0^0 = 1.
           (declare ((integer 0 #.most-positive-fixnum) b e))
           (let ((acc 1))
             (declare ((integer 0 #.most-positive-fixnum) acc))
             (loop until (zerop e)
                   do (when (logbitp 0 e)
                        (setq acc (mod (* acc b) +mod+)))
                      (setq b (mod (* b b) +mod+)
                            e (ash e -1)))
             acc)))
    (let ((sum 0))
      (declare (fixnum sum))
      (loop for i from 0 to k
            for term = (mod (* (binom k i) (pow-mod i n)) +mod+)
            ;; Keep SUM in [0, +MOD+) by one correction after each step.
            do (if (evenp (- k i))
                   (when (>= (incf sum term) +mod+)
                     (decf sum +mod+))
                   (when (< (decf sum term) 0)
                     (incf sum +mod+))))
      (mod (* sum (aref *fact-inv* k)) +mod+))))

(declaim (inline catalan))
(defun catalan (n)
  "Returns the N-th Catalan number (2N)! / ((N+1)! N!) modulo +MOD+."
  (declare ((integer 0 #.most-positive-fixnum) n))
  (let ((inverse-denominator (mod (* (aref *fact-inv* (+ n 1))
                                     (aref *fact-inv* n))
                                  +mod+)))
    (mod (* (aref *fact* (* 2 n)) inverse-denominator) +mod+)))

(defpackage :cp/geometric-sequence
  (:use :cl)
  (:export #:make-geometric-sequence))
(in-package :cp/geometric-sequence)

(declaim (inline make-geometric-sequence))
(defun make-geometric-sequence (rate length modulus &key (scale 1) (element-type '(unsigned-byte 31)))
  "Builds a fresh vector V of LENGTH elements of ELEMENT-TYPE, where
V[i] = SCALE * RATE^i mod MODULUS. Each term is obtained from the previously
stored one by a single multiplication by RATE."
  (declare (fixnum rate scale)
           ((mod #.array-dimension-limit) length)
           ((integer 1 #.most-positive-fixnum) modulus))
  (let ((seq (make-array length :element-type element-type)))
    (when (plusp length)
      (setf (aref seq 0) (mod scale modulus))
      (do ((i 1 (1+ i)))
          ((>= i length))
        (setf (aref seq i) (mod (* rate (aref seq (1- i))) modulus))))
    seq))

(defpackage :cp/mod-power
  (:use :cl)
  (:export #:mod-power))
(in-package :cp/mod-power)

(declaim (inline mod-power))
(defun mod-power (base power modulus)
  "Computes BASE^POWER mod MODULUS by binary (square-and-multiply)
exponentiation. By convention 0^0 = 1, so POWER = 0 yields (MOD 1 MODULUS).

BASE := integer
POWER, MODULUS := non-negative fixnum"
  (declare ((integer 0 #.most-positive-fixnum) modulus power)
           (integer base))
  (let ((square (mod base modulus))
        (acc (mod 1 modulus))
        (exponent power))
    (declare ((integer 0 #.most-positive-fixnum) square acc exponent))
    (do ()
        ((zerop exponent) acc)
      (when (logbitp 0 exponent)
        (setq acc (mod (* acc square) modulus)))
      (setq square (mod (* square square) modulus)
            exponent (ash exponent -1)))))

(defpackage :cp/mod-inverse
  (:use :cl)
  #+sbcl (:import-from #:sb-c #:defoptimizer #:lvar-type #:integer-type-numeric-bounds
                       #:derive-type #:flushable #:foldable)
  #+sbcl (:import-from :sb-kernel #:specifier-type)
  (:export #:mod-inverse))
(in-package :cp/mod-inverse)

;; On SBCL, register %MOD-INVERSE and MOD-INVERSE as known functions. This
;; makes them flushable and foldable and lets the compiler derive their result
;; type as [0, upper bound of MODULUS's type).
#+sbcl
(eval-when (:compile-toplevel :load-toplevel :execute)
  (sb-c:defknown %mod-inverse ((integer 0) (integer 1)) (integer 0)
      (flushable foldable)
    :overwrite-fndb-silently t)
  (sb-c:defknown mod-inverse (integer (integer 1)) (integer 0)
      (flushable foldable)
    :overwrite-fndb-silently t)
  ;; Result type [0, hi) where hi is the upper bound of MODULUS's type,
  ;; or (INTEGER 0) when that bound is not a known integer.
  (defun derive-mod (modulus)
    (let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
      (specifier-type (if (integerp high)
                          `(integer 0 (,high))
                          `(integer 0)))))
  (defoptimizer (%mod-inverse derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))
  (defoptimizer (mod-inverse derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus)))

;; Extended Euclidean algorithm. It maintains A = U * INTEGER (mod MODULUS).
;; When B reaches 0, A = gcd(INTEGER, MODULUS). U is then returned, with
;; MODULUS added if it is negative. U is the modular inverse only when the
;; gcd is 1; MOD-INVERSE verifies this.
;; FROB picks the declared integer type of the intermediates from the range
;; of MODULUS. NOTE(review): under (SAFETY 0) those declarations are trusted.
;; They presumably rely on 0 <= INTEGER < MODULUS, which MOD-INVERSE ensures.
(defun %mod-inverse (integer modulus)
  (declare (optimize (speed 3) (safety 0))
           #+sbcl (sb-ext:muffle-conditions sb-ext:compiler-note))
  (macrolet ((frob (stype)
               `(let ((a integer)
                      (b modulus)
                      (u 1)
                      (v 0))
                  (declare (,stype a b u v))
                  ;; NOTE(review): UNTIL precedes the FOR clause. SBCL's LOOP
                  ;; evaluates the clauses in order, so FLOOR never sees B = 0.
                  ;; ANSI, however, expects variable clauses before main clauses.
                  (loop until (zerop b)
                        for quot = (floor a b)
                        do (decf a (the ,stype (* quot b)))
                           (rotatef a b)
                           (decf u (the ,stype (* quot v)))
                           (rotatef u v))
                  (if (< u 0)
                      (+ u modulus)
                      u))))
    (typecase modulus
      ((unsigned-byte 31) (frob (signed-byte 32)))
      ((unsigned-byte 62) (frob (signed-byte 63)))
      (otherwise (frob integer)))))

(declaim (inline mod-inverse))
(defun mod-inverse (integer modulus)
  "Returns the inverse of INTEGER modulo MODULUS, i.e. X with
INTEGER * X = 1 (mod MODULUS). Signals DIVISION-BY-ZERO when INTEGER and
MODULUS are not coprime. No check is made when MODULUS is 1."
  (let* ((residue (mod integer modulus))
         (candidate (%mod-inverse residue modulus)))
    (when (and (/= modulus 1)
               (/= (mod (* residue candidate) modulus) 1))
      (error 'division-by-zero
             :operands (list residue modulus)
             :operation 'mod-inverse))
    candidate))

(defpackage :cp/ntt
  (:use :cl :cp/mod-inverse :cp/barrett)
  (:export #:define-ntt #:check-ntt-vector #:ntt-int #:ntt-vector)
  (:documentation
   "Provides fast number theoretic transform.

Reference:
https://github.com/ei1333/library/blob/master/math/fft/number-theoretic-transform-friendly-mod-int.cpp
https://github.com/atcoder/ac-library/tree/master/atcoder"))
(in-package :cp/ntt)

;; Element type of NTT vectors: nonnegative integers below 2^31.
(deftype ntt-int () '(unsigned-byte 31))
;; Simple one-dimensional vector of NTT-INT.
(deftype ntt-vector () '(simple-array ntt-int (*)))

;; These helpers are called by DEFINE-NTT at macroexpansion time, hence the
;; EVAL-WHEN.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (declaim (inline %tzcount))
  (defun %tzcount (x)
    "Returns the number of the trailing zero bits. Note that (%TZCOUNT 0) = -1."
    (- (integer-length (logand x (- x))) 1))
  ;; BASE^EXP mod MODULUS by square-and-multiply. All arguments are declared
  ;; NTT-INT. RES starts at 1, so EXP = 0 returns 1 even when MODULUS = 1.
  (defun %mod-power (base exp modulus)
    (declare (ntt-int base exp modulus))
    (let ((res 1))
      (declare (ntt-int res))
      (loop while (> exp 0)
            when (oddp exp)
            do (setq res (mod (* res base) modulus))
            do (setq base (mod (* base base) modulus)
                     exp (ash exp -1)))
      res))
  ;; Inverse via Fermat's little theorem, X^(MODULUS-2). This is valid only
  ;; for prime MODULUS and X not divisible by it.
  (defun %mod-inverse (x modulus)
    (%mod-power x (- modulus 2) modulus))
  ;; Returns a primitive root modulo the prime MODULUS. Known roots of common
  ;; NTT primes are tabulated. Otherwise the distinct prime factors of
  ;; MODULUS-1 are collected into DIVS: 2 first, then odd factors found by
  ;; trial division of the odd part. The result is the smallest G >= 2 with
  ;; G^((MODULUS-1)/p) /= 1 for every such factor p. Twenty slots suffice
  ;; because a number below 2^31 has at most 9 distinct prime factors.
  (defun %calc-generator (modulus)
    "MODULUS must be prime."
    (declare (ntt-int modulus))
    (assert (>= modulus 2))
    (case modulus
      (2 1)
      (167772161 3)
      (469762049 3)
      (754974721 11)
      (998244353 3)
      (otherwise
       (let ((divs (make-array 20 :element-type 'ntt-int :initial-element 0))
             (end 1)
             (x (floor (- modulus 1) 2)))
         (declare ((integer 0 #.most-positive-fixnum) x))
         (setf (aref divs 0) 2)
         (loop while (evenp x)
               do (setq x (floor x 2)))
         (loop for i of-type ntt-int from 3 by 2
               while (<= (* i i) x)
               when (zerop (mod x i))
               do (setf (aref divs end) i)
                  (incf end)
                  (loop while (zerop (mod x i))
                        do (setq x (floor x i))))
         ;; A cofactor left over after trial division is itself prime.
         (when (> x 1)
           (setf (aref divs end) x)
           (incf end))
         ;; DOTIMES's result form runs only when no factor disqualified G.
         (loop for g of-type ntt-int from 2
               do (dotimes (i end (return-from %calc-generator g))
                    (when (= 1 (%mod-power g (floor (- modulus 1) (aref divs i)) modulus))
                      (return)))))))))

(declaim (ftype (function * (values ntt-vector &optional)) %adjust-array))
(defun %adjust-array (vector length)
  "Returns a fresh NTT-VECTOR of LENGTH elements holding the leading elements
of VECTOR, zero-padded or truncated as needed. ANSI CL leaves unspecified
whether CL:ADJUST-ARRAY copies; this function always copies."
  (declare (optimize (speed 3))
           (vector vector)
           ((mod #.array-dimension-limit) length))
  (let ((source (coerce vector 'ntt-vector)))
    (if (/= length (length source))
        (replace (make-array length :element-type 'ntt-int :initial-element 0)
                 source)
        (copy-seq source))))

(defun check-ntt-vector (vector)
  ;; Signals an error via ASSERT unless VECTOR's length is zero or a power of
  ;; two and is of type NTT-INT. Returns NIL when the check passes.
  (declare (optimize (speed 3))
           (vector vector))
  (let ((len (length vector)))
    (assert (and (zerop (logand len (- len 1))) ; power of two
                 (typep len 'ntt-int)))))

;; (DEFINE-NTT MODULUS &KEY NTT INVERSE-NTT CONVOLVE) expands into the
;; following definitions for a constant prime MODULUS. The default names are
;; NTT!, INVERSE-NTT! and CONVOLVE, interned in the package that is current
;; at macroexpansion time.
;;   NTT          forward transform of a vector whose length is a power of two.
;;   INVERSE-NTT  reverse transform. With the optional INVERSE argument true,
;;                it also scales by 1/len, which makes it the inverse of NTT.
;;   CONVOLVE     product of two coefficient vectors modulo MODULUS.
;; It also defines two load-time-global twiddle tables, named by gensyms.
;; NTT and INVERSE-NTT work on (COERCE VECTOR 'NTT-VECTOR). COERCE returns
;; VECTOR itself when it is already an NTT-VECTOR, in which case the caller's
;; vector is overwritten.
(defmacro define-ntt (modulus &key ntt inverse-ntt convolve &environment env)
  (assert (constantp modulus env))
  (let* ((modulus #+sbcl (sb-int:constant-form-value modulus env) #-sbcl modulus)
         (ntt (or ntt (intern "NTT!")))
         (inverse-ntt (or inverse-ntt (intern "INVERSE-NTT!")))
         (convolve (or convolve (intern "CONVOLVE")))
         (ntt-base (gensym "*NTT-BASE*"))
         (ntt-inv-base (gensym "*NTT-INV-BASE*"))
         ;; Exponent of the largest power of two dividing MODULUS - 1.
         (base-size (%tzcount (- modulus 1)))
         (root (%calc-generator modulus))
         ;; NOTE(review): MODULUS is already a number here, so this is a no-op.
         ;; It calls SB-INT unconditionally, which means the #-SBCL branch
         ;; above does not make the macro portable.
         (modulus (sb-int:constant-form-value modulus env)))
    (declare (ntt-int modulus))
    `(progn
       ;; Entry i of the base table is -(ROOT^floor((MODULUS-1) / 2^(i+2)))
       ;; mod MODULUS. The inverse table holds the inverses of those entries.
       (declaim (ntt-vector ,ntt-base ,ntt-inv-base))
       (sb-ext:define-load-time-global ,ntt-base
         (make-array ,base-size :element-type 'ntt-int))
       (sb-ext:define-load-time-global ,ntt-inv-base
         (make-array ,base-size :element-type 'ntt-int))
       (dotimes (i ,base-size)
         (setf (aref ,ntt-base i)
               (mod (- (%mod-power ,root (ash (- ,modulus 1) (- (+ i 2))) ,modulus))
                    ,modulus)
               (aref ,ntt-inv-base i)
               (%mod-inverse (aref ,ntt-base i) ,modulus)))

       (declaim (ftype (function * (values ntt-vector &optional)) ,ntt))
       (defun ,ntt (vector)
         (declare (optimize (speed 3) (safety 0))
                  (vector vector))
         (check-ntt-vector vector)
         ;; Modular arithmetic on residues: Barrett product, and a single
         ;; conditional correction for sum and difference.
         (labels ((mod* (x y) (fast-mod (* x y) ,modulus))
                  (mod+ (x y) (%himod (+ x y) ,modulus))
                  (mod- (x y) (%lomod (- x y) ,modulus)))
           (declare (inline mod* mod+ mod-))
           (let* ((vector (coerce vector 'ntt-vector))
                  (len (length vector))
                  (base ,ntt-base))
             (declare (ntt-vector vector base)
                      (ntt-int len))
             (when (<= len 1)
               (return-from ,ntt vector))
             ;; Butterfly half-width M runs len/2, len/4, ..., 1. W and K
             ;; (no THEN clause) are reset to 1 and 0 at every level. W is the
             ;; twiddle factor of the current block and K counts blocks; after
             ;; each block W is updated with the table entry indexed by the
             ;; trailing-zero count of K.
             (loop for m of-type ntt-int = (ash len -1) then (ash m -1)
                   while (> m 0)
                   for w of-type ntt-int = 1
                   for k of-type ntt-int = 0
                   do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                            do (loop for i from s below (+ s m)
                                     for j from (+ s m)
                                     for x = (aref vector i)
                                     for y = (mod* (aref vector j) w)
                                     do (setf (aref vector i) (mod+ x y)
                                              (aref vector j) (mod- x y)))
                               (incf k)
                               (setq w (mod* w (aref base (%tzcount k))))))
             vector)))

       (declaim (ftype (function * (values ntt-vector &optional)) ,inverse-ntt))
       (defun ,inverse-ntt (vector &optional inverse)
         (declare (optimize (speed 3) (safety 0))
                  (vector vector))
         (check-ntt-vector vector)
         (labels ((mod* (x y) (fast-mod (* x y) ,modulus))
                  (mod+ (x y) (%himod (+ x y) ,modulus))
                  (mod- (x y) (%lomod (- x y) ,modulus)))
           (declare (inline mod* mod+ mod-))
           (let* ((vector (coerce vector 'ntt-vector))
                  (len (length vector))
                  (base ,ntt-inv-base))
             (declare (ntt-vector vector base)
                      (ntt-int len))
             (when (<= len 1)
               (return-from ,inverse-ntt vector))
             ;; Mirror of the forward loop. M runs 1, 2, 4, ... below LEN, the
             ;; twiddle is applied after the difference, and the inverse table
             ;; is used.
             (loop for m of-type ntt-int = 1 then (ash m 1)
                   while (< m len)
                   for w of-type ntt-int = 1
                   for k of-type ntt-int = 0
                   do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                            do (loop for i from s below (+ s m)
                                     for j from (+ s m)
                                     for x = (aref vector i)
                                     for y = (aref vector j)
                                     do (setf (aref vector i) (mod+ x y)
                                              (aref vector j) (mod* (mod- x y) w)))
                               (incf k)
                               (setq w (mod* w (aref base (%tzcount k))))))
             ;; Optional 1/len scaling.
             (when inverse
               (let ((inv-len (mod-inverse len ,modulus)))
                 (dotimes (i len)
                   (setf (aref vector i) (mod* inv-len (aref vector i))))))
             vector)))

       ;; Returns a fresh vector of length len1 + len2 - 1, or an empty vector if
       ;; either input is empty. VECTOR1 and VECTOR2 are not modified.
       (declaim (ftype (function * (values ntt-vector &optional)) ,convolve))
       (defun ,convolve (vector1 vector2)
         (declare (optimize (speed 3))
                  (vector vector1 vector2))
         ;; TODO: if (EQ VECTOR1 VECTOR2) holds, the number of FFTs can be
         ;; reduced.
         (let* ((len1 (length vector1))
                (len2 (length vector2))
                (mul-len (max 0 (- (+ len1 len2) 1)))
                (vector1 (coerce vector1 'ntt-vector))
                (vector2 (coerce vector2 'ntt-vector)))
           (declare (ntt-vector vector1 vector2)
                    ((mod #.array-dimension-limit) mul-len))
           (when (or (zerop len1) (zerop len2))
             (return-from ,convolve (make-array 0 :element-type 'ntt-int)))
           ;; naive convolution
           ;; Used when the shorter input has at most 64 elements. It does
           ;; O(len1 * len2) work and reduces after every multiply-add.
           (when (<= (min len1 len2) 64)
             (let ((res (make-array mul-len :element-type 'ntt-int :initial-element 0)))
               (declare (optimize (speed 3) (safety 0)))
               (dotimes (d mul-len)
                 ;; 0 <= i <= deg1, 0 <= j <= deg2
                 (loop with coef of-type ntt-int = 0
                       for i from (max 0 (- d (- len2 1))) to (min d (- len1 1))
                       for j = (- d i)
                       do (setq coef (fast-mod (+ coef (* (aref vector1 i) (aref vector2 j)))
                                               ,modulus))
                       finally (setf (aref res d) coef)))
               (return-from ,convolve res)))
           ;; NTT path: zero-pad copies of both inputs to the next power of two
           ;; >= MUL-LEN, transform, multiply pointwise, inverse-transform with
           ;; scaling, and keep the first MUL-LEN coefficients.
           (let* (;; power of two ceiling
                  (required-len (ash 1 (integer-length (max 0 (- mul-len 1)))))
                  (vector1 (,ntt (%adjust-array vector1 required-len)))
                  (vector2 (,ntt (%adjust-array vector2 required-len))))
             (dotimes (i required-len)
               (setf (aref vector1 i)
                     (fast-mod (* (aref vector1 i) (aref vector2 i)) ,modulus)))
             (subseq (,inverse-ntt vector1 t) 0 mul-len)))))))

;; Disabled example instantiation: #+(OR) makes the reader skip the form.
#+(or)
(define-ntt +mod+)


(defpackage :cp/polynomial-ntt
  (:use :cl :cp/ntt :cp/mod-inverse :cp/mod-power :cp/static-mod)
  (:export #:poly-multiply #:poly-inverse #:poly-floor #:poly-mod #:poly-sub #:poly-add
           #:multipoint-eval #:poly-total-prod #:chirp-z #:bostan-mori
           #:poly-differentiate! #:poly-integrate #:poly-log #:poly-exp #:poly-power))
(in-package :cp/polynomial-ntt)

;; TODO: integrate with cp/polynomial

;; Instantiates the NTT for CP/STATIC-MOD:+MOD+, naming the convolution
;; POLY-MULTIPLY. NTT! and INVERSE-NTT! keep their default names and are
;; interned in this package.
(define-ntt +mod+
  :convolve poly-multiply)

(declaim (inline %adjust))
(defun %adjust (vector size)
  "Returns VECTOR itself when SIZE is NIL or equals its length. Otherwise
returns a fresh NTT-VECTOR of SIZE elements holding VECTOR's leading elements,
zero-padded or truncated."
  (declare (ntt-vector vector))
  (if (and size (/= size (length vector)))
      (replace (make-array size :element-type 'ntt-int :initial-element 0)
               vector)
      vector))

(declaim (inline %power-of-two-ceiling))
(defun %power-of-two-ceiling (x)
  "Returns the smallest power of two that is >= X.
Both X = 0 and X = 1 give 1."
  (declare (ntt-int x))
  (let ((bits (integer-length (1- x))))
    (ash 1 bits)))

;; (declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
;; (defun poly-inverse (poly &optional result-length)
;;   (declare (optimize (speed 3))
;;            (vector poly)
;;            ((or null fixnum) result-length))
;;   (let* ((poly (coerce poly 'ntt-vector))
;;          (n (length poly)))
;;     (declare (ntt-vector poly))
;;     (when (or (zerop n)
;;               (zerop (aref poly 0)))
;;       (error 'division-by-zero
;;              :operation #'poly-inverse
;;              :operands (list poly)))
;;     (let ((res (make-array 1
;;                            :element-type 'ntt-int
;;                            :initial-element (mod-inverse (aref poly 0) +mod+)))
;;           (result-length (or result-length n)))
;;       (declare (ntt-vector res))
;;       (loop for i of-type ntt-int = 1 then (ash i 1)
;;             while (< i result-length)
;;             for decr = (poly-multiply (poly-multiply res res)
;;                                       (subseq poly 0 (min (length poly) (* 2 i))))
;;             for decr-len = (length decr)
;;             do (setq res (%adjust res (* 2 i) :initial-element 0))
;;                (dotimes (j (* 2 i))
;;                  (setf (aref res j)
;;                        (mod (the ntt-int
;;                                  (+ (mod (* 2 (aref res j)) +mod+)
;;                                     (if (>= j decr-len) 0 (- +mod+ (aref decr j)))))
;;                             +mod+))))
;;       (%adjust res result-length))))

;; Reference: https://opt-cp.com/fps-fast-algorithms/
(declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
(defun poly-inverse (poly &optional result-length)
  "Returns the multiplicative inverse of POLY as a power series.

The result is truncated to RESULT-LENGTH coefficients; the default is the
length of POLY. Signals DIVISION-BY-ZERO if POLY is empty or its constant
term is zero.

Uses Newton iteration. Starting from the single coefficient 1/POLY[0], each
round doubles the number of known coefficients."
  (declare (optimize (speed 3))
           (vector poly)
           ((or null fixnum) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (n (length poly)))
    (declare (ntt-vector poly))
    (when (or (zerop n)
              (zerop (aref poly 0)))
      (error 'division-by-zero
             :operation #'poly-inverse
             :operands (list poly)))
    (let* ((result-length (or result-length n))
           (res (make-array 1
                            :element-type 'ntt-int
                            :initial-element (mod-inverse (aref poly 0) +mod+))))
      (declare (ntt-vector res))
      ;; At the start of each round RES holds the first I coefficients of the
      ;; inverse; the round extends it to 2I.
      (loop for i of-type ntt-int = 1 then (ash i 1)
            while (< i result-length)
            ;; Fresh zero-filled buffers of length 2I for the transforms.
            for f of-type ntt-vector = (make-array (* 2 i) :element-type 'ntt-int
                                                           :initial-element 0)
            for g of-type ntt-vector = (make-array (* 2 i) :element-type 'ntt-int
                                                           :initial-element 0)
            do (replace f poly :end2 (min n (* 2 i)))  ; F := POLY mod x^2I
               (replace g res)                          ; G := current RES
               (ntt! f)
               (ntt! g)
               ;; Pointwise product and back-transform: F becomes the cyclic
               ;; (length 2I) product POLY*RES, up to the normalization factor
               ;; discussed below.
               (dotimes (j (* 2 i))
                 (setf (aref f j) (mod (* (aref g j) (aref f j)) +mod+)))
               (inverse-ntt! f)
               ;; Move the upper half (coefficients I..2I-1) of that product
               ;; down to the lower half and clear the upper half.
               (replace f f :start1 0 :end1 i :start2 i :end2 (* 2 i))
               (fill f 0 :start i :end (* 2 i))
               ;; Multiply that upper half by RES once more.
               (ntt! f)
               (dotimes (j (* 2 i))
                 (setf (aref f j) (mod (* (aref g j) (aref f j)) +mod+)))
               (inverse-ntt! f)
               ;; NOTE(review): INVERSE-NTT! is called here without the extra T
               ;; argument that DEFINE-NTT's convolution passes, so it
               ;; presumably skips the 1/length normalization. The factor
               ;; -1/(2I)^2 computed below would then undo two such transforms
               ;; and supply the Newton-step sign -- confirm against cp/ntt.
               (let ((inv-len (mod-inverse (* 2 i) +mod+)))
                 (setq inv-len (mod (* inv-len (- +mod+ inv-len))
                                    +mod+))
                 (dotimes (j i)
                   (setf (aref f j) (mod (* inv-len (aref f j)) +mod+)))
                 ;; The first I entries of F become coefficients I..2I-1 of RES.
                 (setq res (%adjust res (* 2 i)))
                 (replace res f :start1 i)))
      (%adjust res result-length))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-floor))
(defun poly-floor (poly1 poly2)
  "Returns the quotient of the polynomial division POLY1 / POLY2.

Trailing zero coefficients of both inputs are ignored. An empty vector is
returned when POLY2 has more significant coefficients than POLY1. A zero
POLY2 ends in POLY-INVERSE's DIVISION-BY-ZERO.

Uses reversal: the reversed quotient equals reversed POLY1 times the inverse
series of reversed POLY2, truncated to the quotient length."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (flet ((significant-length (p)
           ;; Number of coefficients up to and including the last non-zero one.
           (let ((top (position 0 p :from-end t :test-not #'eql)))
             (if top (+ top 1) 0))))
    (let* ((a (coerce poly1 'ntt-vector))
           (b (coerce poly2 'ntt-vector))
           (len-a (significant-length a))
           (len-b (significant-length b)))
      (if (> len-b len-a)
          (make-array 0 :element-type 'ntt-int)
          (let* ((rev-a (nreverse (subseq a 0 len-a)))
                 (rev-b (nreverse (subseq b 0 len-b)))
                 (quot-len (+ 1 (- len-a len-b)))
                 (rev-quot (%adjust (poly-multiply rev-a (poly-inverse rev-b quot-len))
                                    quot-len)))
            (nreverse rev-quot))))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-sub))
(defun poly-sub (poly1 poly2)
  "Returns POLY1 - POLY2 with trailing zero coefficients removed.

A negative difference is brought back into range by adding +MOD+ once, so
input coefficients are expected to lie in [0, +MOD+)."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((a (coerce poly1 'ntt-vector))
         (b (coerce poly2 'ntt-vector))
         (diff (replace (make-array (max (length a) (length b))
                                    :element-type 'ntt-int :initial-element 0)
                        a)))
    (dotimes (k (length b))
      (let ((d (- (aref diff k) (aref b k))))
        (setf (aref diff k) (if (minusp d) (+ d +mod+) d))))
    (let ((top (position 0 diff :from-end t :test-not #'eql)))
      (%adjust diff (if top (+ top 1) 0)))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-add))
(defun poly-add (poly1 poly2)
  "Returns POLY1 + POLY2, with each coefficient reduced modulo +MOD+ and
trailing zero coefficients removed."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((a (coerce poly1 'ntt-vector))
         (b (coerce poly2 'ntt-vector))
         (sum (replace (make-array (max (length a) (length b))
                                   :element-type 'ntt-int :initial-element 0)
                       a)))
    (dotimes (k (length b))
      (setf (aref sum k) (mod (+ (aref sum k) (aref b k)) +mod+)))
    (let ((top (position 0 sum :from-end t :test-not #'eql)))
      (%adjust sum (if top (+ top 1) 0)))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-mod))
(defun poly-mod (poly1 poly2)
  "Returns the remainder POLY1 - floor(POLY1 / POLY2) * POLY2.

The result is a fresh vector with trailing zero coefficients removed. An
all-zero or empty POLY1 yields an empty vector without performing any
division."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let ((a (coerce poly1 'ntt-vector))
        (b (coerce poly2 'ntt-vector)))
    (if (every #'zerop a)
        (make-array 0 :element-type 'ntt-int)
        (let* ((remainder (poly-sub a (poly-multiply (poly-floor a b) b)))
               (top (position 0 remainder :from-end t :test-not #'eql)))
          (subseq remainder 0 (if top (+ top 1) 0))))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-total-prod))
(defun poly-total-prod (polys)
  "Returns polys[0] * polys[1] * ... * polys[n-1] as an NTT-VECTOR.

The empty product is the constant polynomial 1. Factors are combined
pairwise in a balanced tree. In each round, slot I absorbs slot I + STRIDE
and STRIDE doubles, until the full product has accumulated in slot 0."
  (declare (optimize (speed 3))
           (vector polys))
  (let ((size (length polys)))
    (declare ((mod #.array-dimension-limit) size))
    (if (zerop size)
        (make-array 1 :element-type 'ntt-int :initial-element 1)
        (let ((slots (make-array size :element-type t)))
          (replace slots polys)
          (do ((stride 1 (ash stride 1)))
              ((>= stride size))
            (declare ((mod #.array-dimension-limit) stride))
            (do ((i 0 (+ i (* stride 2))))
                ((>= (+ i stride) size))
              (declare ((mod #.array-dimension-limit) i))
              (setf (aref slots i)
                    (poly-multiply (aref slots i) (aref slots (+ i stride))))))
          (coerce (the vector (aref slots 0)) 'ntt-vector)))))

(declaim (ftype (function * (values ntt-vector &optional)) multipoint-eval))
(defun multipoint-eval (poly points)
  "Evaluates POLY at each element of POINTS.

Returns the values in the same order as an NTT-VECTOR. The length of POINTS
must be a power of two.

Builds a subproduct tree stored heap-style in a flat vector; the children of
node K sit at 2K+1 and 2K+2. Leaves hold the linear factors x - POINTS[L] and
inner nodes hold the product of their children. POLY is reduced modulo every
node on the way down, and the remainder at a leaf is the value at its point."
  (declare (optimize (speed 3))
           (vector poly points)
           #+sbcl (sb-ext:muffle-conditions style-warning))
  (check-ntt-vector points)
  (let* ((poly (coerce poly 'ntt-vector))
         (points (coerce points 'ntt-vector))
         (size (length points))
         (tree (make-array (max 0 (- (* 2 size) 1)) :element-type 'ntt-vector))
         (out (make-array size :element-type 'ntt-int)))
    (when (plusp size)
      (labels ((build (lo hi node)
                 (declare ((mod #.array-dimension-limit) lo hi node))
                 (if (= (- hi lo) 1)
                     ;; Leaf: x - POINTS[LO]. The constant term is stored as
                     ;; +MOD+ - POINTS[LO], which is never zero.
                     (let ((factor (make-array 2 :element-type 'ntt-int)))
                       (setf (aref factor 0) (- +mod+ (aref points lo))
                             (aref factor 1) 1
                             (aref tree node) factor))
                     (let ((mid (ash (+ lo hi) -1))
                           (left (+ (* node 2) 1))
                           (right (+ (* node 2) 2)))
                       (build lo mid left)
                       (build mid hi right)
                       (setf (aref tree node)
                             (poly-multiply (aref tree left) (aref tree right))))))
               (descend (remainder lo hi node)
                 (declare ((mod #.array-dimension-limit) lo hi node))
                 (if (= (- hi lo) 1)
                     (let ((leaf (poly-mod remainder (aref tree node))))
                       (setf (aref out lo) (if (zerop (length leaf)) 0 (aref leaf 0))))
                     (let ((mid (ash (+ lo hi) -1))
                           (left (+ (* 2 node) 1))
                           (right (+ (* 2 node) 2)))
                       (descend (poly-mod remainder (aref tree left)) lo mid left)
                       (descend (poly-mod remainder (aref tree right)) mid hi right)))))
        (build 0 size 0)
        (descend poly 0 size 0)))
    out))

;; not tested
(declaim (ftype (function * (values ntt-vector &optional)) chirp-z))
(defun chirp-z (poly base length)
  "Does multipoint evaluation of POLY with powers of BASE: P(base^0), P(base^1), ...,
P(base^(length-1)). BASE must be coprime with modulus. Time complexity is
O((N+M)log(N+M)), where N is the length of POLY and M is LENGTH.

Reference:
https://codeforces.com/blog/entry/83532"
  (declare (optimize (speed 3))
           (vector poly)
           (ntt-int base length))
  (when (zerop (length poly))
    (return-from chirp-z (make-array length :element-type 'ntt-int :initial-element 0)))
  ;; Let T(k) = k(k-1)/2. Then i*k = T(i+k) - T(i) - T(k), so
  ;;   P(base^k) = base^(-T(k)) * sum_i a_i * base^(-T(i)) * base^(T(i+k)),
  ;; which is one convolution of CS with DS.
  (let* ((poly (coerce poly 'ntt-vector))
         (binv (mod-inverse base +mod+))
         (n (length poly))
         (m (max length n))
         (n+m (+ n m))
         (cs (make-array n :element-type 'ntt-int :initial-element 0))
         (ds (make-array n+m :element-type 'ntt-int :initial-element 0)))
    (declare (ntt-int n m n+m))
    ;; CS[i] = a_j * base^(-T(j)) with j = n-1-i (coefficients reversed).
    (dotimes (i n)
      (setf (aref cs i) (mod (* (aref poly (- n 1 i))
                                (mod-power binv (ash (* (- n 1 i) (- n 2 i)) -1) +mod+))
                             +mod+)))
    ;; DS[i] = base^T(i). Indices up to n-1 + LENGTH-1 < n+m are needed.
    (dotimes (i n+m)
      (setf (aref ds i) (mod-power base (ash (* i (- i 1)) -1) +mod+)))
    ;; Entry n-1+k of CS*DS is sum_j a_j * base^(T(j+k) - T(j)).
    (let ((result (subseq (poly-multiply cs ds)
                          (- n 1)
                          (+ (- n 1) length))))
      ;; Finish with the factor base^(-T(k)).
      (dotimes (i length)
        (setf (aref result i)
              (mod (* (aref result i)
                      (mod-power binv (ash (* i (- i 1)) -1) +mod+))
                   +mod+)))
      result)))

(declaim (ftype (function * (values ntt-int &optional)) bostan-mori))
(defun bostan-mori (index num denom)
  "Returns [x^index](num(x)/denom(x)), the coefficient of x^INDEX in the
power series NUM/DENOM.

Signals DIVISION-BY-ZERO when DENOM is empty or DENOM(0) = 0.

Each round multiplies numerator and denominator by DENOM(-x). The new
denominator contains only even powers, so INDEX can be halved. The numerator
keeps its even or odd part according to the parity of INDEX.

Reference:
https://arxiv.org/abs/2008.08822
https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47 (Japanese)"
  (declare (optimize (speed 3))
           (unsigned-byte index)
           (vector num denom))
  (flet ((every-other (p start)
           ;; Coefficients P[START], P[START+2], ... where START is 0 or 1.
           (let ((part (make-array (floor (- (+ (length p) 1) start) 2)
                                   :element-type 'ntt-int)))
             (dotimes (k (length part) part)
               (setf (aref part k) (aref p (+ start (* 2 k)))))))
         (reflect (p)
           ;; P(-x): each odd-indexed coefficient is negated modulo +MOD+.
           (let ((q (copy-seq p)))
             (do ((k 1 (+ k 2)))
                 ((>= k (length q)) q)
               (let ((c (aref q k)))
                 (setf (aref q k) (if (zerop c) 0 (- +mod+ c))))))))
    (let ((num (coerce num 'ntt-vector))
          (denom (coerce denom 'ntt-vector)))
      (when (or (zerop (length denom))
                (zerop (aref denom 0)))
        (error 'division-by-zero
               :operands (list num denom)
               :operation 'bostan-mori))
      (do () ((zerop index))
        (let* ((denom- (reflect denom))
               (u (poly-multiply num denom-)))
          (setq num (every-other u (if (evenp index) 0 1))
                denom (every-other (poly-multiply denom denom-) 0)
                index (ash index -1))))
      (rem (* (if (zerop (length num)) 0 (aref num 0))
              (mod-inverse (aref denom 0) +mod+))
           +mod+))))

(declaim (inline poly-differentiate!))
(defun poly-differentiate! (p)
  "Returns the derivative of P with trailing zero coefficients removed.

When P is already an NTT-VECTOR, COERCE returns it unchanged and its storage
is overwritten in place. An empty P is returned as is."
  (declare (vector p))
  (let* ((p (coerce p 'ntt-vector))
         (len (length p)))
    (if (zerop len)
        p
        (let ((last-index (- len 1)))
          ;; Shift down by one, multiplying coefficient K by K.
          (loop for k of-type ntt-int from 1 to last-index
                do (setf (aref p (- k 1))
                         (mod (* (aref p k) k) +mod+)))
          ;; The derivative has only LEN-1 meaningful entries; trim zeros there.
          (let ((top (position 0 p :from-end t :end last-index :test-not #'eql)))
            (subseq p 0 (if top (+ top 1) 0)))))))

;; Cache of modular inverses: entry X (X >= 1) holds 1/X mod +MOD+, and
;; entry 0 is a placeholder 0. FILL-INV! grows it on demand.
(declaim (ntt-vector *inv*))
(defparameter *inv* (make-array 2 :element-type 'ntt-int :initial-contents '(0 1)))

(defun fill-inv! (new-size)
  "Grows *INV* so that it has at least NEW-SIZE entries, rounded up to a
power of two. Returns NIL.

Each new entry X is filled with 1/X mod +MOD+ using the recurrence
inv(X) = -(floor(MOD / X)) * inv(MOD mod X). Since MOD mod X < X, the
right-hand side only uses entries that are already filled."
  (declare (optimize (speed 3))
           ((mod #.array-dimension-limit) new-size))
  (let* ((have (length *inv*))
         (want (%power-of-two-ceiling (max have new-size))))
    (when (> want have)
      (let ((table (%adjust *inv* want)))
        (declare (ntt-vector table))
        (do ((x have (+ x 1)))
            ((>= x want))
          (setf (aref table x)
                (- +mod+
                   (mod (* (aref table (rem +mod+ x)) (floor +mod+ x))
                        +mod+))))
        (setq *inv* table)
        nil))))

(declaim (inline poly-integrate))
(defun poly-integrate (p)
  "Returns an indefinite integral of P with integration constant zero.

RESULT[0] = 0 and RESULT[K] = P[K-1] / K (mod +MOD+) for K >= 1. An empty P
yields an empty vector."
  (declare (vector p))
  (let* ((p (coerce p 'ntt-vector))
         (len (length p)))
    (if (zerop len)
        (make-array 0 :element-type 'ntt-int)
        (progn
          (fill-inv! (+ len 1))
          (let ((out (make-array (+ len 1) :element-type 'ntt-int :initial-element 0))
                (inverses *inv*))
            (loop for k from 1 to len
                  do (setf (aref out k)
                           (mod (* (aref p (- k 1)) (aref inverses k))
                                +mod+)))
            out)))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-log))
(defun poly-log (poly &optional result-length)
  "Returns log(POLY) truncated to RESULT-LENGTH coefficients.

RESULT-LENGTH defaults to the length of POLY. The result is computed as the
integral of POLY' / POLY. POLY must be non-empty with constant term 1;
otherwise the assertion fails."
  (declare (optimize (speed 3))
           (vector poly)
           ((or null (integer 1 (#.array-dimension-limit))) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (result-length (or result-length (length poly))))
    ;; Check the length first, so an empty POLY fails this assertion instead
    ;; of signalling an out-of-bounds AREF.
    (assert (and (> (length poly) 0) (= 1 (aref poly 0))))
    (let ((res (poly-integrate
                (%adjust (poly-multiply
                          ;; POLY-DIFFERENTIATE! overwrites an NTT-VECTOR
                          ;; argument, and COERCE may have returned the
                          ;; caller's vector, so pass a copy.
                          (poly-differentiate! (copy-seq poly))
                          (poly-inverse poly result-length))
                         (- result-length 1)))))
      (%adjust res result-length))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-exp))
(defun poly-exp (poly &optional result-length)
  "Returns exp(POLY) truncated to RESULT-LENGTH coefficients.

RESULT-LENGTH defaults to the length of POLY. POLY must be empty or have
constant term 0.

Uses Newton iteration, RES := RES * (1 + POLY - log RES) mod x^NEW-LEN, which
doubles the number of correct coefficients each round. RESULT-LENGTH may
exceed the length of POLY; missing coefficients of POLY are treated as zero."
  (declare (optimize (speed 3))
           (vector poly)
           ((or null (mod #.array-dimension-limit)) result-length))
  (assert (or (zerop (length poly)) (zerop (aref poly 0))))
  (let* ((poly (coerce poly 'ntt-vector))
         (poly-len (length poly))
         (result-length (or result-length poly-len))
         (res (make-array 1 :element-type 'ntt-int :initial-element 1)))
    (declare (ntt-vector poly res))
    (loop until (>= (length res) result-length)
          do (let* ((new-len (* 2 (length res)))
                    (g (poly-log res new-len)))
               (declare ((mod #.array-dimension-limit) new-len)
                        (ntt-vector g))
               ;; G := POLY - log(RES) on all NEW-LEN coefficients. Positions
               ;; beyond POLY must still be negated (POLY is zero there);
               ;; otherwise RESULT-LENGTH > (length POLY) gives wrong answers.
               (dotimes (i new-len)
                 (setf (aref g i)
                       (mod (- (if (< i poly-len) (aref poly i) 0)
                               (aref g i))
                            +mod+)))
               ;; G := 1 + POLY - log(RES)
               (setf (aref g 0) (mod (+ 1 (aref g 0)) +mod+))
               (setq res (%adjust (poly-multiply res g) new-len))))
    (%adjust res result-length)))

(declaim (ftype (function * (values ntt-vector &optional)) poly-power))
(defun poly-power (poly exp &optional result-length)
  "Returns POLY^EXP truncated to RESULT-LENGTH coefficients.

RESULT-LENGTH defaults to the length of POLY.

POLY is written as c * x^s * Q(x), where c is its lowest non-zero
coefficient, s is that coefficient's index, and Q(0) = 1. Q^EXP is computed
as exp(EXP * log Q), scaled by c^EXP, and shifted right by s*EXP.

A vector of zeros is returned when POLY has no non-zero coefficient, when
RESULT-LENGTH is zero, or when s*EXP exceeds RESULT-LENGTH."
  (declare (optimize (speed 3))
           (vector poly)
           ((integer 0 #.most-positive-fixnum) exp)
           ((or null (mod #.array-dimension-limit)) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (result-length (or result-length (length poly)))
         (lowest (position 0 poly :test-not #'eql)))
    (if (or (null lowest)
            (zerop result-length)
            (> (* lowest exp) result-length))
        (make-array result-length :element-type 'ntt-int :initial-element 0)
        (let* ((lead (aref poly lowest))
               (series (subseq poly lowest))
               (lead-inv (mod-inverse lead +mod+)))
          ;; Q := POLY[LOWEST..] / LEAD, so that Q(0) = 1.
          (map-into series (lambda (c) (mod (* c lead-inv) +mod+)) series)
          ;; Q^EXP = exp(EXP * log Q); the multiplier is EXP reduced mod +MOD+.
          (setq series (poly-log series result-length))
          (let ((scaled-exp (mod exp +mod+)))
            (map-into series (lambda (c) (mod (* c scaled-exp) +mod+)) series))
          (setq series (poly-exp series result-length))
          ;; Restore the leading coefficient: multiply by LEAD^EXP.
          (let ((lead-power (mod-power lead exp +mod+)))
            (map-into series (lambda (c) (mod (* c lead-power) +mod+)) series))
          ;; Restore the x^LOWEST factor: shift by LOWEST*EXP.
          (if (zerop lowest)
              series
              (replace (make-array result-length :element-type 'ntt-int :initial-element 0)
                       series
                       :start1 (min result-length (* lowest exp))))))))

;; BEGIN_USE_PACKAGE
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/polynomial-ntt :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/geometric-sequence :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/binom-mod-prime :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/read-fixnum :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/mod-operations :cl-user))
(in-package :cl-user)

;;;
;;; Body
;;;

;; (floor (+ +mod+ 1) 2) is the inverse of 2 modulo the odd modulus +MOD+,
;; since 2 * (+MOD+ + 1)/2 = +MOD+ + 1. NOTE(review): assuming
;; MAKE-GEOMETRIC-SEQUENCE takes (ratio length modulus), this table holds
;; (1/2)^0, (1/2)^1, ... with 51000 entries, enough for CALC(X) with X <= 50000.
(declaim ((simple-array uint31 (*)) *power/2*))
(sb-ext:define-load-time-global *power/2*
  (make-geometric-sequence (floor (+ +mod+ 1) 2) 51000 +mod+))

(declaim (inline calc))
(defun calc (x)
  "Combines (aref *fact* 2X), (aref *power/2* X) and (aref *fact-inv* X)
with MOD*.

NOTE(review): this assumes those tables hold (2X)!, 2^-X and 1/X! modulo
+MOD+. Under that assumption the result is (2X)! / (2^X X!), the number of
ways to split 2X people into X unordered pairs."
  (let ((fact-2x (aref *fact* (* 2 x)))
        (half-power (aref *power/2* x))
        (inv-fact-x (aref *fact-inv* x)))
    (mod* fact-2x half-power inv-fact-x)))

;; Reads N followed by 2N heights, then prints (mod +MOD+) an alternating
;; inclusion-exclusion sum over the number of same-height pairs. This is
;; presumably the answer to AtCoder ABL F, "Heights and Pairs": pairings of
;; the 2N people in which no pair shares a height.
(defun main ()
  (declare #.*opt*)
  (let* ((n (read))
         ;; Occurrence count per height. NOTE(review): assumes heights lie in
         ;; [1, 100000]; under the judge policy (safety 0) an out-of-range
         ;; height is not bounds-checked.
         (counter (make-array 100000 :element-type 'uint31 :initial-element 0)))
    (declare (uint31 n))
    (dotimes (i (* 2 n))
      (let ((h (- (read-fixnum) 1)))
        (incf (aref counter h))))
    ;; Keep only the heights that actually occur.
    (let* ((counter (delete 0 counter))
           (len (length counter))
           (polys (make-array len :element-type t)))
      ;; For a height with C people, coefficient J is
      ;; binom(C, 2J) * CALC(J): choose 2J of them and pair those up.
      (dotimes (i len)
        (let* ((c (aref counter i))
               (deg (floor c 2))
               (poly (make-array (+ deg 1) :element-type 'uint31 :initial-element 0)))
          (dotimes (j (length poly))
            (setf (aref poly j)
                  (mod* (binom c (* 2 j)) (calc j))))
          (setf (aref polys i) poly)))
      ;; DP[I] sums, over all heights, the ways to form I same-height pairs.
      ;; Since the counts total 2N, (length DP) <= N + 1, so N - I >= 0.
      (let ((dp (poly-total-prod polys))
            (res 0))
        (declare (uint31 res))
        ;; Inclusion-exclusion: the remaining 2(N - I) people are paired
        ;; freely (CALC), with sign (-1)^I.
        (dotimes (i (length dp))
          (let ((val (mod* (calc (- n i)) (aref dp i))))
            (if (evenp i)
                (incfmod res val)
                (decfmod res val))))
        (println res)))))

;; Entry point for the judge build.
#-swank (main)

;;;
;;; Test
;;;

;; Development (swank) only: locate this source file, make its directory the
;; default pathname and the current directory, and point *DAT-PATHNAME* at
;; test.dat next to it.
#+swank
(progn
  (defparameter *lisp-file-pathname* (uiop:current-lisp-file-pathname))
  (setq *default-pathname-defaults* (uiop:pathname-directory-pathname *lisp-file-pathname*))
  (uiop:chdir *default-pathname-defaults*)
  (defparameter *dat-pathname* (uiop:merge-pathnames* "test.dat" *lisp-file-pathname*))
  (defparameter *problem-url* "https://atcoder.jp/contests/abl/tasks/abl_f"))

#+swank
(defun gen-dat ()
  "Regenerates the test data file *DAT-PATHNAME*, replacing any existing file.
It currently writes nothing, so the file ends up empty."
  (uiop:with-output-file (stream *dat-pathname* :if-exists :supersede)
    (format stream "")))

#+swank
(defun bench (&optional (out (make-broadcast-stream)))
  "Times one RUN over *DAT-PATHNAME* and sends the program's output to OUT.
The default, a broadcast stream with no targets, discards the output."
  (let ((input *dat-pathname*))
    (time (run input out))))

;; Judge build only: at compile time, turn any undefined-name warnings that
;; SBCL has recorded so far into a hard error, so they cannot pass silently.
#+(and sbcl (not swank))
(eval-when (:compile-toplevel)
  (when sb-c::*undefined-warnings*
    (error "undefined warnings: ~{~A~^ ~}" sb-c::*undefined-warnings*)))

;; To run: (5am:run! :sample)
;; Sample cases: RUN is given the input text and NIL. NOTE(review): RUN comes
;; from outside this file and presumably returns the printed output as a
;; string in that case -- confirm.
#+swank
(5am:test :sample
  (5am:is
   (equal "2
"
          (run "2
1
1
2
3
" nil)))
  (5am:is
   (equal "516
"
          (run "5
30
10
20
40
20
10
10
30
50
60
" nil))))

Submission Info

Submission Time
Task F - Heights and Pairs
User sansaqua
Language Common Lisp (SBCL 2.0.3)
Score 600
Code Size 53122 Byte
Status AC
Exec Time 110 ms
Memory 38652 KiB

Compile Error

; file: /imojudge/sandbox/Main.lisp
; in: DEFUN POLY-TOTAL-PROD
;     (REPLACE CP/POLYNOMIAL-NTT::DP CP/POLYNOMIAL-NTT::POLYS)
; 
; note: unable to
;   optimize
; due to type uncertainty:
;   The second argument is a VECTOR, not a SIMPLE-VECTOR.

; file: /imojudge/sandbox/Main.lisp
; in: DEFUN CHIRP-Z
;     (SUBSEQ
;      (CP/POLYNOMIAL-NTT:POLY-MULTIPLY CP/POLYNOMIAL-NTT::CS
;       CP/POLYNOMIAL-NTT::DS)
;      (- CP/POLYNOMIAL-NTT::N 1) (+ (- CP/POLYNOMIAL-NTT::N 1) LENGTH))
; --> LET* OR LET IF 
; ==>
;   LENGTH
; 
; note: deleting unreachable code

; file: /imojudge/sandbox/Main.lisp
; in: DEFUN BOSTAN-MORI
;     (> CP/POLYNOMIAL-NTT::INDEX 0)
; 
; note: forced to do FAST-IF->-ZERO (cost 8)
;       unable to do inline fixnum comparison (cost 3) because:
;       The first argument is a UNSIGNED-BYTE, not a FIXNUM.

;     (ASH CP/POLYNOMIAL-NTT::INDEX -1)
; 
; note: forced to do full call
;       unable to do inline ASH (cost 2) because:
;       The first argument is a (INTEGER 1), not a FIXNUM.
;       The result is a (VALUES UNSIGNED-BYTE &OPTIONAL), not a (VALUES FIXNUM
;                                                                       &REST T).
;       unable to do inline ASH (cost 3) because:
;       The first argument is a (INTEGER 1), not a (UNSIGNED-BYTE 64).
;       The result is a (VALUES UNSIGNED-BYTE &OPTIONAL), not a (VALUES
;                                                                (UNSIGNED-BYTE
;                                                                 64)
;                                                                &REST T).
;       etc.

;     (DEFUN CP/POLYNOMIAL-NTT:BOSTAN-MORI
;            (CP/POLYNOMIAL-NTT::INDEX CP/POLYNOMIAL-NTT::NUM
;             CP/POLYNOMIAL-NTT::DENOM)
;       "Returns [x^index](num(x)/denom(x)).
;   
;   Reference:
;   https://arxiv.org/abs/2008.08822
;   https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47 (Japanese)"
;       (DECLARE (OPTIMIZE (SPEED 3))
;                (UNSIGNED-BYTE CP/POLYNOMIAL-NTT::INDEX)
;                (VECTOR CP/PO...

Judge Result

Set Name Sample All
Score / Max Score 0 / 0 600 / 600
Status
AC × 2
AC × 23
Set Name Test Cases
Sample example0.txt, example1.txt
All 000.txt, 001.txt, 002.txt, 003.txt, 004.txt, 005.txt, 006.txt, 007.txt, 008.txt, 009.txt, 010.txt, 011.txt, 012.txt, 013.txt, 014.txt, 015.txt, 016.txt, 017.txt, 018.txt, 019.txt, 020.txt, example0.txt, example1.txt
Case Name Status Exec Time Memory
000.txt AC 44 ms 36788 KiB
001.txt AC 40 ms 28672 KiB
002.txt AC 46 ms 29388 KiB
003.txt AC 45 ms 29828 KiB
004.txt AC 45 ms 30332 KiB
005.txt AC 53 ms 30288 KiB
006.txt AC 55 ms 30728 KiB
007.txt AC 59 ms 30948 KiB
008.txt AC 62 ms 31400 KiB
009.txt AC 64 ms 31840 KiB
010.txt AC 74 ms 32600 KiB
011.txt AC 78 ms 33616 KiB
012.txt AC 86 ms 34500 KiB
013.txt AC 88 ms 35028 KiB
014.txt AC 95 ms 35844 KiB
015.txt AC 95 ms 35948 KiB
016.txt AC 94 ms 36244 KiB
017.txt AC 110 ms 37748 KiB
018.txt AC 109 ms 38140 KiB
019.txt AC 109 ms 38652 KiB
020.txt AC 86 ms 37608 KiB
example0.txt AC 22 ms 28536 KiB
example1.txt AC 33 ms 28472 KiB