提出 #74603770


ソースコード 拡げる

package main

import (
	. "io"
	"math"
	"math/bits"
	"os"
	"reflect"
	"sort"
	"strconv"
)

// Integer is the set of built-in integer types (signed and unsigned) and any
// type whose underlying type is one of them.
type Integer interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}

// Realnumber extends Integer with the floating-point types.
type Realnumber interface {
	Integer |
		~float32 | ~float64
}

// Complexnumber extends Realnumber with the complex types.
type Complexnumber interface {
	Realnumber |
		~complex64 | ~complex128
}

// Abs returns |x|. For the most negative value of a signed integer type the
// negation wraps, so the result stays negative.
func Abs[T Realnumber](x T) T {
	if x < 0 {
		return -x
	}
	return x
}

// Gcd returns the greatest common divisor of x and y, made non-negative, with
// Gcd(0, 0) == 0. Negative inputs are accepted.
//
// It runs Euclid's algorithm directly on the signed values and takes the
// absolute value at the end. The previous version first recursed on Abs(x),
// Abs(y), which never terminated for the minimum signed value (its Abs stays
// negative). For that value the true result may be unrepresentable and then
// wraps, but the call always returns.
func Gcd[T Integer](x, y T) T {
	for y != 0 {
		x, y = y, x%y
	}
	if x < 0 {
		x = -x
	}
	return x
}

// Lcm returns the least common multiple of x and y, or 0 when either is 0
// (previously Lcm(0, 0) panicked with a division by zero). For non-zero
// inputs the sign follows x*y, as before. Dividing before multiplying keeps
// the intermediate value small.
func Lcm[T Integer](x, y T) T {
	if x == 0 || y == 0 {
		return 0
	}
	return x / Gcd(x, y) * y
}

// Pow returns x^n mod m by binary exponentiation. It expects n >= 0 and m >= 1.
//
// x is reduced modulo m first, so products stay below m*m. For m == 1 the
// result is 0; it used to be 1. A negative x yields Go's truncated
// remainder. (m-1)*(m-1) must fit in S.
func Pow[S, T Integer](x S, n T, m S) S {
	x %= m
	r := 1 % m
	for ; n > 0; n >>= 1 {
		if n&1 == 1 {
			r = r * x % m
		}
		x = x * x % m
	}
	return r
}

// Ctz returns the number of trailing zero bits of x viewed as a uint
// (the width of uint when x == 0). Negative values are sign-extended first.
func Ctz[T Integer](x T) int {
	return bits.TrailingZeros(uint(x))
}

// Clz returns the number of leading zero bits of x viewed as a uint.
// The count is relative to the width of uint, not the width of T.
func Clz[T Integer](x T) int {
	return bits.LeadingZeros(uint(x))
}

// Popcount returns the number of set bits of x viewed as a uint
// (negative values are sign-extended first).
func Popcount[T Integer](x T) int {
	return bits.OnesCount(uint(x))
}

// IO is a buffered reader/writer pair used for token-based input and
// batched output.
type IO struct {
	in         Reader
	out        Writer
	rbuf, wbuf []byte // rbuf: current input chunk; wbuf: output pending until Flush
	i, n       int    // i: next unread index in rbuf; n: number of valid bytes in rbuf
	fpc        int    // float precision used by Write; -1 = fewest digits that round-trip
}

// NewIO returns an IO reading from in and writing to out. It has a 4096-byte
// read buffer and float precision -1.
func NewIO(in Reader, out Writer) *IO {
	return &IO{in: in, out: out, rbuf: make([]byte, 4096), wbuf: make([]byte, 0), fpc: -1}
}

// NewStdIO returns an IO bound to os.Stdin and os.Stdout.
func NewStdIO() *IO {
	return NewIO(os.Stdin, os.Stdout)
}

// read_byte returns the next input byte, refilling rbuf from in when it is
// exhausted.
//
// It returns 0 once a Read yields no bytes (treated as end of input), and
// keeps returning 0 on later calls. The Read error is deliberately ignored:
// a failing reader shows up as a zero-byte read.
func (io *IO) read_byte() byte {
	if io.i == io.n {
		io.n, _ = io.in.Read(io.rbuf)
		// Reset i even when nothing was read. Previously i kept its old value
		// while n became 0, so the next call returned stale bytes from rbuf
		// (and eventually indexed past its end) instead of 0.
		io.i = 0
		if io.n == 0 {
			return 0
		}
	}
	b := io.rbuf[io.i]
	io.i++
	return b
}
// Read fills each pointer in ptrs, in order, from the input.
//
// Supported targets are pointers to the built-in integer types, float32,
// float64 and string. Pointers to slices or arrays are filled element by
// element, recursing into nested slices and arrays. Other pointer types are
// ignored.
//
// Integer targets skip every non-digit byte before the first digit. A '-'
// immediately before that digit makes a signed value negative; previously
// any '-' in the skipped bytes did, so "- 5" read as -5. Digits accumulate
// with wrap-around on overflow. If input ends before a digit appears, Read
// returns and leaves that target and all later ones unchanged.
//
// Float and string targets read the next run of bytes that are not space,
// tab, '\r' or '\n'. Strings previously did not treat tab as a separator.
// At end of input a string becomes "". A float becomes whatever
// strconv.ParseFloat returns for the token; that is 0 for an empty or
// malformed token, and parse errors are ignored.
func (io *IO) Read(ptrs ...any) {
	for _, p := range ptrs {
		ok := true
		switch v := p.(type) {
		case *uint:
			ok = readUnsigned(io, v)
		case *uint8:
			ok = readUnsigned(io, v)
		case *uint16:
			ok = readUnsigned(io, v)
		case *uint32:
			ok = readUnsigned(io, v)
		case *uint64:
			ok = readUnsigned(io, v)
		case *int:
			ok = readSigned(io, v)
		case *int8:
			ok = readSigned(io, v)
		case *int16:
			ok = readSigned(io, v)
		case *int32:
			ok = readSigned(io, v)
		case *int64:
			ok = readSigned(io, v)
		case *float32:
			w, _ := strconv.ParseFloat(io.readToken(), 32)
			*v = float32(w)
		case *float64:
			*v, _ = strconv.ParseFloat(io.readToken(), 64)
		case *string:
			*v = io.readToken()
		default:
			rv := reflect.ValueOf(p)
			if rv.Kind() == reflect.Ptr && (rv.Elem().Kind() == reflect.Slice || rv.Elem().Kind() == reflect.Array) {
				io.readElems(rv.Elem())
			}
		}
		if !ok {
			return
		}
	}
}

// readUnsigned skips to the next decimal digit and parses an unsigned decimal
// into *dst, wrapping on overflow. It returns false, leaving *dst unchanged,
// if input ends before a digit.
func readUnsigned[T ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64](io *IO, dst *T) bool {
	b := io.read_byte()
	for b < '0' || b > '9' {
		if b == 0 {
			return false
		}
		b = io.read_byte()
	}
	var x T
	for '0' <= b && b <= '9' {
		x = x*10 + T(b&15)
		b = io.read_byte()
	}
	*dst = x
	return true
}

// readSigned is readUnsigned for signed targets. A '-' directly before the
// first digit negates the result.
//
// The magnitude is accumulated in uint64 and then converted. Conversion and
// negation wrap modulo 2^bits, so the minimum value of T round-trips
// correctly without a special case.
func readSigned[T ~int | ~int8 | ~int16 | ~int32 | ~int64](io *IO, dst *T) bool {
	neg := false
	b := io.read_byte()
	for b < '0' || b > '9' {
		if b == 0 {
			return false
		}
		// Only the byte right before the first digit decides the sign.
		neg = b == '-'
		b = io.read_byte()
	}
	var mag uint64
	for '0' <= b && b <= '9' {
		mag = mag*10 + uint64(b&15)
		b = io.read_byte()
	}
	x := T(mag)
	if neg {
		x = -x
	}
	*dst = x
	return true
}

// readToken skips separators (space, tab, '\r', '\n') and returns the
// following run of non-separator bytes. It returns "" at end of input. The
// byte that ends the token is consumed.
func (io *IO) readToken() string {
	isSep := func(b byte) bool { return b == ' ' || b == '\n' || b == '\r' || b == '\t' }
	b := io.read_byte()
	for isSep(b) {
		b = io.read_byte()
	}
	var s []byte
	for b != 0 && !isSep(b) {
		s = append(s, b)
		b = io.read_byte()
	}
	return string(s)
}

// readElems fills every element of the slice or array v in order. Nested
// slices are descended into directly; other elements (including nested
// arrays) are read through Read via their address.
func (io *IO) readElems(v reflect.Value) {
	for i := 0; i < v.Len(); i++ {
		elem := v.Index(i)
		if elem.Kind() == reflect.Slice {
			io.readElems(elem)
		} else {
			io.Read(elem.Addr().Interface())
		}
	}
}
// Write appends the text form of each argument to the output buffer. Output
// is sent to the underlying writer only by Flush. Top-level arguments are
// separated by a single space.
//
// Formatting rules:
//   - Integers are written in decimal.
//   - Strings are written verbatim.
//   - float32/float64 use 'f' format with the precision from SetPrecision
//     (-1: fewest digits that round-trip for the value's own type).
//   - A slice/array whose element type is a slice is written one element per
//     line, with no trailing newline.
//   - Any other slice/array has its elements space-separated.
//   - Values of other types produce no text.
func (io *IO) Write(a ...any) {
	uitos := func(v uint64) []byte {
		var s []byte
		if v == 0 {
			return []byte{'0'}
		}
		for v > 0 {
			s = append(s, '0'+byte(v%10))
			v /= 10
		}
		for i := 0; i < len(s)/2; i++ {
			s[i], s[len(s)-1-i] = s[len(s)-1-i], s[i]
		}
		return s
	}
	itos := func(v int64) []byte {
		if v == 0 {
			return []byte{'0'}
		}
		// -MinInt64 overflows, so its digits are spelled out.
		if v == math.MinInt64 {
			return []byte("-9223372036854775808")
		}
		neg := v < 0
		if neg {
			v = -v
		}
		var s []byte
		for v > 0 {
			s = append(s, '0'+byte(v%10))
			v /= 10
		}
		if neg {
			s = append(s, '-')
		}
		for i := 0; i < len(s)/2; i++ {
			s[i], s[len(s)-1-i] = s[len(s)-1-i], s[i]
		}
		return s
	}

	for i, p := range a {
		if i != 0 {
			io.wbuf = append(io.wbuf, ' ')
		}
		switch v := p.(type) {
		case uint:
			io.wbuf = append(io.wbuf, uitos(uint64(v))...)
		case uint8:
			io.wbuf = append(io.wbuf, uitos(uint64(v))...)
		case uint16:
			io.wbuf = append(io.wbuf, uitos(uint64(v))...)
		case uint32:
			io.wbuf = append(io.wbuf, uitos(uint64(v))...)
		case uint64:
			io.wbuf = append(io.wbuf, uitos(v)...)

		case int:
			io.wbuf = append(io.wbuf, itos(int64(v))...)
		case int8:
			io.wbuf = append(io.wbuf, itos(int64(v))...)
		case int16:
			io.wbuf = append(io.wbuf, itos(int64(v))...)
		case int32:
			io.wbuf = append(io.wbuf, itos(int64(v))...)
		case int64:
			io.wbuf = append(io.wbuf, itos(v)...)

		case float32:
			// bitSize 32: with precision -1 the shortest digits are chosen for
			// float32. Using 64 printed float32(0.1) as 0.10000000149011612.
			io.wbuf = strconv.AppendFloat(io.wbuf, float64(v), 'f', io.fpc, 32)
		case float64:
			io.wbuf = strconv.AppendFloat(io.wbuf, v, 'f', io.fpc, 64)
		case string:
			io.wbuf = append(io.wbuf, v...)
		default:
			rv := reflect.ValueOf(p)
			if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array {
				if rv.Type().Elem().Kind() == reflect.Slice {
					// Nested slices: one element per line.
					for j := 0; j < rv.Len(); j++ {
						if j+1 == rv.Len() {
							io.Write(rv.Index(j).Interface())
						} else {
							io.Writeln(rv.Index(j).Interface())
						}
					}
				} else {
					for j := 0; j < rv.Len(); j++ {
						if j != 0 {
							io.wbuf = append(io.wbuf, ' ')
						}
						io.Write(rv.Index(j).Interface())
					}
				}
			}
		}
	}
}
// SetPrecision sets how many digits after the decimal point Write emits for
// floats. -1 (the value NewIO starts with) selects the fewest digits that
// round-trip.
func (io *IO) SetPrecision(x int) {
	io.fpc = x
}

// Writeln writes its arguments exactly like Write, then a newline.
func (io *IO) Writeln(a ...any) {
	io.Write(a...)
	io.wbuf = append(io.wbuf, '\n')
}

// Flush hands all buffered output to the underlying writer and empties the
// buffer, keeping its capacity for reuse. The write result is ignored.
func (io *IO) Flush() {
	pending := io.wbuf
	io.out.Write(pending)
	io.wbuf = pending[:0]
}

// Close flushes pending output. It does not close the underlying reader or
// writer.
func (io *IO) Close() {
	io.Flush()
}

// MedianFinderWaveletMatrix answers median and absolute-distance-sum queries
// on top of a WaveletMatrixWithSum. It keeps a reference to the matrix rather
// than a copy.
type MedianFinderWaveletMatrix struct {
	Wm *WaveletMatrixWithSum
}

// NewMedianFinderWaveletMatrix wraps wm.
//
// NOTE(review): the distance-sum methods use wm's sum data as the element
// values. They presumably require wm to have been built with sumData equal
// to nums — confirm at call sites.
func NewMedianFinderWaveletMatrix(wm *WaveletMatrixWithSum) *MedianFinderWaveletMatrix {
	mf := new(MedianFinderWaveletMatrix)
	mf.Wm = wm
	return mf
}

// Median returns the median of the whole sequence. For an even length,
// upper selects the larger of the two middle elements.
func (mf *MedianFinderWaveletMatrix) Median(upper bool) int {
	return mf.MedianRange(0, mf.Wm.n, upper)
}

// MedianRange returns the median of positions [st, end), clamped to the
// sequence bounds, or 0 if the clamped range is empty. For an even length,
// upper selects the larger of the two middle elements.
func (mf *MedianFinderWaveletMatrix) MedianRange(st, end int32, upper bool) int {
	st, end = max(st, 0), min(end, mf.Wm.n)
	if end <= st {
		return 0
	}
	return mf.Wm.Median(st, end, upper, 0)
}

// DistSum returns the sum of |v - to| over the whole sequence (see DistSumRange).
func (mf *MedianFinderWaveletMatrix) DistSum(to int) int {
	return mf.DistSumRange(to, 0, mf.Wm.n)
}

// DistSumRange returns the sum of |v - to| over positions [st, end) (clamped),
// or 0 for an empty range.
//
// Elements below `to` contribute to-v and the rest contribute v-to. Both parts
// come from one counted range sum plus the total range sum.
func (mf *MedianFinderWaveletMatrix) DistSumRange(to int, st, end int32) int {
	st, end = max(st, 0), min(end, mf.Wm.n)
	if end <= st {
		return 0
	}
	size := int(end - st)
	below, belowSum := mf.Wm.RangeCountAndSum(st, end, -INF, WmValue(to), 0)
	rangeSum := mf.Wm.SumAll(st, end)
	nb := int(below)
	lowPart := to*nb - belowSum
	highPart := (rangeSum - belowSum) - to*(size-nb)
	return lowPart + highPart
}

// DistSumToMedian returns DistSumToMedianRange over the whole sequence.
func (mf *MedianFinderWaveletMatrix) DistSumToMedian() int {
	return mf.DistSumToMedianRange(0, mf.Wm.n)
}

// DistSumToMedianRange returns the sum of |v - med| over positions [st, end)
// (clamped), or 0 for an empty range. Here med is the (len/2)-th smallest
// value, 0-based, i.e. the upper median.
func (mf *MedianFinderWaveletMatrix) DistSumToMedianRange(st, end int32) int {
	st, end = max(st, 0), min(end, mf.Wm.n)
	if end <= st {
		return 0
	}
	size := end - st
	lowCount := size / 2
	highCount := size - lowCount
	med, lowSum := mf.Wm.KthValueAndSum(st, end, lowCount, 0)
	highSum := mf.Wm.SumAll(st, end) - lowSum
	return (med*int(lowCount) - lowSum) + (highSum - med*int(highCount))
}

// INF is the sentinel value: Floor returns -INF and Ceil returns INF when no
// element qualifies.
const INF WmValue = 1e18

type (
	// WmValue is the type of the stored element values.
	WmValue = int
	// WmSum is the type of the per-element data aggregated in prefix sums.
	WmSum = int
)

// e, op and inv define the group used for prefix sums: identity 0, addition,
// negation. inv lets a range aggregate be taken as op(prefix[r], inv(prefix[l])).
func (wm *WaveletMatrixWithSum) e() WmSum { return 0 }

func (wm *WaveletMatrixWithSum) op(a, b WmSum) WmSum { return a + b }

func (wm *WaveletMatrixWithSum) inv(a WmSum) WmSum { return -a }

// WaveletMatrixWithSum is a wavelet matrix over non-negative values. It can
// also keep level-by-level prefix sums of optional per-element data.
type WaveletMatrixWithSum struct {
	n, log   int32        // element count; number of bit levels
	setLog   bool         // log was passed explicitly (required for xor queries)
	compress bool         // values were replaced by their index in key
	useSum   bool         // sum data was supplied and presum is maintained
	mid      []int32      // mid[d]: number of elements whose bit d is 0
	bv       []*BitVector // bv[d]: bit d of each element, in the order before the level-d partition
	key      []WmValue    // sorted distinct original values (compress only)
	presum   [][]WmSum    // presum[d]: prefix op-sums of the sum data, presum[log] in original order
}

// NewWaveletMatrixWithSum builds a wavelet matrix over nums.
//
// nums: the array elements. They are copied; the caller's slice is not modified.
// sumData: per-element data aggregated with op. nil (or empty) means no sum data.
// log: number of bit levels. It must be passed explicitly if xor queries are
// needed; -1 means derive it from the maximum value.
// compress: whether to coordinate-compress nums (replace values by their rank
// among distinct values). This can speed things up for a large value range
// (e.g. 1e9). It cannot be combined with an explicit log.
func NewWaveletMatrixWithSum(nums []WmValue, sumData []WmSum, log int32, compress bool) *WaveletMatrixWithSum {
	wm := &WaveletMatrixWithSum{}
	wm.build(nums, sumData, log, compress)
	return wm
}

// build initializes wm from copies of nums and sumData. The parameters are
// the same as for NewWaveletMatrixWithSum.
//
// Levels are processed from d = log-1 down to 0. Each level stably partitions
// the current order by bit d: zeros first, then ones.
//   - bv[d] records bit d of each element in the order before that partition.
//   - mid[d] is the number of zeros.
//   - presum[d+1] holds prefix sums of the sum data in that same pre-partition
//     order. presum[0] is filled for the order after the last partition.
//
// NOTE(review): only the low log bits of each value are examined (A[i]>>d&1
// for d < log). Values are presumably expected to be non-negative and below
// 1<<log — confirm for callers that pass an explicit log.
func (wm *WaveletMatrixWithSum) build(nums []WmValue, sumData []WmSum, log int32, compress bool) {
	// Work on copies so the caller's slices are never reordered.
	numsCopy := append(nums[:0:0], nums...)
	sumDataCopy := append(sumData[:0:0], sumData...)

	wm.n = int32(len(numsCopy))
	wm.log = log
	wm.setLog = log != -1
	wm.compress = compress
	wm.useSum = len(sumData) > 0
	if wm.n == 0 {
		// Empty matrix: no levels and a single one-entry prefix-sum row.
		wm.log = 0
		wm.presum = [][]WmSum{{wm.e()}}
		return
	}

	if compress {
		if wm.setLog {
			panic("compress and log should not be set at the same time")
		}
		// Replace each value by the index of its distinct value in sorted order.
		// wm.key maps those indices back to the original values.
		wm.key = make([]WmValue, 0, wm.n)
		order := wm._argSort(numsCopy)
		for _, i := range order {
			if len(wm.key) == 0 || wm.key[len(wm.key)-1] != numsCopy[i] {
				wm.key = append(wm.key, numsCopy[i])
			}
			numsCopy[i] = WmValue(len(wm.key) - 1)
		}
		// Trim capacity to the number of distinct values.
		wm.key = wm.key[:len(wm.key):len(wm.key)]
	}
	if wm.log == -1 {
		// Use enough bits for the largest value (at least 1 bit).
		tmp := wm._maxs(numsCopy)
		if tmp < 1 {
			tmp = 1
		}
		wm.log = int32(bits.Len(uint(tmp)))
	}
	wm.mid = make([]int32, wm.log)
	wm.bv = make([]*BitVector, wm.log)
	for i := range wm.bv {
		wm.bv[i] = NewBitVector(wm.n)
	}
	if wm.useSum {
		// One prefix-sum row per level boundary: log+1 rows of n+1 entries.
		wm.presum = make([][]WmSum, 1+wm.log)
		for i := range wm.presum {
			sums := make([]WmSum, wm.n+1)
			for j := range sums {
				sums[j] = wm.e()
			}
			wm.presum[i] = sums
		}
	}
	if len(sumDataCopy) == 0 {
		// Placeholder so the partition loop below can move S alongside A.
		sumDataCopy = make([]WmSum, len(numsCopy))
	}

	// A/S hold the current order; A0/S0 and A1/S1 collect the zero and one
	// halves. The buffers are swapped each level instead of reallocated.
	A, S := numsCopy, sumDataCopy
	A0, A1 := make([]WmValue, wm.n), make([]WmValue, wm.n)
	S0, S1 := make([]WmSum, wm.n), make([]WmSum, wm.n)
	for d := wm.log - 1; d >= -1; d-- {
		p0, p1 := int32(0), int32(0)
		if wm.useSum {
			// Prefix sums for the order entering this level (presum[d+1]).
			tmp := wm.presum[d+1]
			for i := int32(0); i < wm.n; i++ {
				tmp[i+1] = wm.op(tmp[i], S[i])
			}
		}
		if d == -1 {
			// The extra iteration only fills presum[0] for the final order.
			break
		}
		for i := int32(0); i < wm.n; i++ {
			f := (A[i] >> d & 1) == 1
			if !f {
				if wm.useSum {
					S0[p0] = S[i]
				}
				A0[p0] = A[i]
				p0++
			} else {
				if wm.useSum {
					S1[p1] = S[i]
				}
				wm.bv[d].Set(i)
				A1[p1] = A[i]
				p1++
			}
		}
		wm.mid[d] = p0
		wm.bv[d].Build()
		// New order = zeros (already in A0) followed by ones (copied from A1).
		A, A0 = A0, A
		S, S0 = S0, S
		for i := int32(0); i < p1; i++ {
			A[p0+i] = A1[i]
			S[p0+i] = S1[i]
		}
	}
}

// RangeCountAndSum returns the number of elements at positions [st, end)
// whose value (xored with xorValue) lies in [a, b), together with the
// op-aggregate of their sum data. The aggregate is e() when no sum data is kept.
//
// The position range is clamped to [0, n). With compress, a and b are
// original values and are mapped to ranks via key. A non-zero xorValue
// requires an explicitly set log.
func (wm *WaveletMatrixWithSum) RangeCountAndSum(st, end int32, a, b WmValue, xorValue WmValue) (int32, WmSum) {
	if xorValue != 0 {
		if !wm.setLog {
			panic("log should be set when xor is used")
		}
	}
	if st < 0 {
		st = 0
	}
	if end > wm.n {
		end = wm.n
	}
	if st >= end || a >= b {
		return 0, wm.e()
	}
	if wm.compress {
		a = wm._lowerBound(wm.key, a)
		b = wm._lowerBound(wm.key, b)
	}
	count, sum := int32(0), wm.e()
	// dfs visits the node covering (xored) values [lx, rx) at level d. Its
	// elements sit at positions [l, r) of that level's order.
	var dfs func(d, l, r int32, lx, rx WmValue)
	dfs = func(d, l, r int32, lx, rx WmValue) {
		if rx <= a || b <= lx {
			return
		}
		if a <= lx && rx <= b {
			// Node lies entirely inside [a, b): take all of it at once.
			count += r - l
			if wm.useSum {
				sum = wm.op(sum, wm._get(d, l, r))
			}
			return
		}
		d--
		mx := (lx + rx) >> 1
		// Zeros of bit d map to [l0, r0) and ones to [l1, r1) in the next level.
		l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)
		l1, r1 := l+wm.mid[d]-l0, r+wm.mid[d]-r0
		if xorValue>>d&1 == 1 {
			// xor flips bit d, so the halves swap roles.
			l0, l1 = l1, l0
			r0, r1 = r1, r0
		}
		dfs(d, l0, r0, lx, mx)
		dfs(d, l1, r1, mx, rx)
	}
	dfs(wm.log, st, end, 0, 1<<wm.log)
	return count, sum
}

// KthValueAndSum works on positions [st, end), clamped to [0, n). It returns
// two things:
//   - the k-th smallest value (k is 0-based) of v^xorVal; the xored value is
//     what gets returned;
//   - the op-aggregate of the sum data of the k elements ranked strictly
//     before it. Equal values are ranked by original position, because every
//     partition in build is stable.
//
// Special cases:
//   - an empty clamped range returns (-1, e());
//   - k >= the range length returns (-1, aggregate of the whole range);
//   - a negative k is treated as 0.
//
// A non-zero xorVal requires an explicitly set log.
func (wm *WaveletMatrixWithSum) KthValueAndSum(st, end, k int32, xorVal WmValue) (WmValue, WmSum) {
	if st < 0 {
		st = 0
	}
	if end > wm.n {
		end = wm.n
	}
	if st >= end {
		return -1, wm.e()
	}
	if k >= end-st {
		return -1, wm.SumAll(st, end)
	}
	if k < 0 {
		// A negative k used to follow the all-zero path and then read presum
		// at st+k < st, possibly a negative index.
		k = 0
	}
	if xorVal != 0 {
		if !wm.setLog {
			panic("log should be set when xor is used")
		}
	}

	sum, val := wm.e(), WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		l0, r0 := wm.bv[d].Rank(st, false), wm.bv[d].Rank(end, false)
		l1, r1 := st+wm.mid[d]-l0, end+wm.mid[d]-r0
		if (xorVal>>d)&1 == 1 {
			l0, l1 = l1, l0
			r0, r1 = r1, r0
		}
		if k < r0-l0 {
			// The answer has (xored) bit d == 0.
			st, end = l0, r0
		} else {
			// Every element of the 0-half ranks before the answer.
			k -= r0 - l0
			val |= 1 << d
			st, end = l1, r1
			if wm.useSum {
				sum = wm.op(sum, wm._get(d, l0, r0))
			}
		}
	}
	// At level 0 equal values are contiguous; add the first k of them.
	if wm.useSum {
		sum = wm.op(sum, wm._get(0, st, st+k))
	}
	if wm.compress {
		val = wm.key[val]
	}
	return val, sum
}

// Kth returns the k-th smallest value (k is 0-based) among positions
// [st, end), xored with xorVal.
//
// The range is clamped to [0, n) first, then k is clamped to
// [0, len-1] of that clamped range. An empty range yields -1.
func (wm *WaveletMatrixWithSum) Kth(st, end, k int32, xorVal WmValue) WmValue {
	// Clamp the range before deriving its length. Using the raw end-st let k
	// point past the last element when the range stuck out of [0, n), which
	// made KthValueAndSum return -1 instead of the largest value.
	st, end = max(st, 0), min(end, wm.n)
	k = min(max(k, 0), end-st-1)
	v, _ := wm.KthValueAndSum(st, end, k, xorVal)
	return v
}

// Median returns the median of positions [st, end) (clamped to [0, n)). For
// an even length, upper picks the larger middle element and !upper the
// smaller. An empty range yields -1.
func (wm *WaveletMatrixWithSum) Median(st, end int32, upper bool, xorVal WmValue) WmValue {
	st, end = max(st, 0), min(end, wm.n)
	n := end - st
	var k int32
	if upper {
		k = n >> 1
	} else {
		k = (n - 1) >> 1
	}
	return wm.Kth(st, end, k, xorVal)
}

// Floor returns the largest value v^xor with v^xor <= x among positions
// [st, end), clamped to [0, n).
//
//	If no such value exists it returns -INF.
//
// With compress, x is an original value and the result is mapped back
// through key. A non-zero xor requires an explicitly set log.
// NOTE(review): x++ below overflows for x == math.MaxInt.
func (wm *WaveletMatrixWithSum) Floor(st, end int32, x WmValue, xor WmValue) WmValue {
	if xor != 0 {
		if !wm.setLog {
			panic("log should be set when xor is used")
		}
	}
	if st < 0 {
		st = 0
	}
	if end > wm.n {
		end = wm.n
	}
	if st >= end {
		return -INF
	}
	res := -INF
	// From here on x is an exclusive bound: only values < x qualify.
	x++
	if wm.compress {
		x = wm._lowerBound(wm.key, x)
	}
	var dfs func(d, l, r int32, lx, rx WmValue)
	dfs = func(d, l, r int32, lx, rx WmValue) {
		// Prune when nothing in [lx, rx) can beat res, the node is empty, or all values are >= x.
		if rx-1 <= res || l == r || x <= lx {
			return
		}
		if d == 0 {
			res = max(res, lx)
			return
		}
		d--
		mx := (lx + rx) >> 1
		l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)
		l1, r1 := l+wm.mid[d]-l0, r+wm.mid[d]-r0
		if xor>>d&1 == 1 {
			l0, l1 = l1, l0
			r0, r1 = r1, r0
		}
		// Upper half first, so res rises early and prunes the lower half.
		dfs(d, l1, r1, mx, rx)
		dfs(d, l0, r0, lx, mx)
	}
	dfs(wm.log, st, end, 0, 1<<wm.log)
	if wm.compress && res != -INF {
		res = wm.key[res]
	}
	return res
}

// Ceil returns the smallest value v^xor with v^xor >= x among positions
// [st, end), clamped to [0, n).
//
//	If no such value exists it returns INF.
//
// With compress, x is an original value and the result is mapped back
// through key. A non-zero xor requires an explicitly set log.
func (wm *WaveletMatrixWithSum) Ceil(st, end int32, x WmValue, xor WmValue) int {
	if xor != 0 {
		if !wm.setLog {
			panic("log should be set when xor is used")
		}
	}
	if st < 0 {
		st = 0
	}
	if end > wm.n {
		end = wm.n
	}
	if st >= end {
		return INF
	}
	if wm.compress {
		x = wm._lowerBound(wm.key, x)
	}
	res := INF
	var dfs func(d, l, r int32, lx, rx WmValue)
	dfs = func(d, l, r int32, lx, rx WmValue) {
		// Prune when nothing in [lx, rx) can beat res, the node is empty, or all values are < x.
		if res <= lx || l == r || rx <= x {
			return
		}
		if d == 0 {
			res = min(res, lx)
			return
		}
		d--
		mx := (lx + rx) >> 1
		l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)
		l1, r1 := l+wm.mid[d]-l0, r+wm.mid[d]-r0
		if xor>>d&1 == 1 {
			l0, l1 = l1, l0
			r0, r1 = r1, r0
		}
		// Lower half first, so res drops early and prunes the upper half.
		dfs(d, l0, r0, lx, mx)
		dfs(d, l1, r1, mx, rx)
	}
	dfs(wm.log, st, end, 0, 1<<wm.log)
	if wm.compress && res < INF {
		res = wm.key[res]
	}
	return res
}

// SumRange returns the op-aggregate of the sum data of elements at positions
// [st, end) (clamped) whose value (xored with xorVal) lies in [a, b). It
// panics if the matrix was built without sum data.
func (wm *WaveletMatrixWithSum) SumRange(st, end int32, a, b WmValue, xorVal WmValue) WmSum {
	if !wm.useSum {
		panic("sum data must be provided")
	}
	st, end = max(st, 0), min(end, wm.n)
	if end <= st || b <= a {
		return wm.e()
	}
	_, total := wm.RangeCountAndSum(st, end, a, b, xorVal)
	return total
}

// SumSlice returns the op-aggregate of the sum data of the elements whose
// rank (0-based, by value xored with xorVal) within positions [st, end) lies
// in [k1, k2). It panics if the matrix was built without sum data.
func (wm *WaveletMatrixWithSum) SumSlice(st, end, k1, k2 int32, xorVal WmValue) WmSum {
	if !wm.useSum {
		panic("sum data must be provided")
	}
	// The ranks are clamped against the range length as given, before the
	// positions are clamped. An over-large rank is still safe: KthValueAndSum
	// returns the whole-range aggregate for any rank >= the range length.
	k1 = max(k1, 0)
	k2 = min(k2, end-st)
	st, end = max(st, 0), min(end, wm.n)
	if end <= st || k2 <= k1 {
		return wm.e()
	}
	_, upTo1 := wm.KthValueAndSum(st, end, k1, xorVal)
	_, upTo2 := wm.KthValueAndSum(st, end, k2, xorVal)
	return wm.op(upTo2, wm.inv(upTo1))
}

// SumAll returns the op-aggregate of the sum data at positions [st, end)
// (clamped), or e() for an empty range or when no sum data is kept.
func (wm *WaveletMatrixWithSum) SumAll(st, end int32) WmSum {
	st, end = max(st, 0), min(end, wm.n)
	if end <= st {
		return wm.e()
	}
	return wm._get(wm.log, st, end)
}

// MaxRight considers the elements at positions [st, end) (clamped to [0, n))
// in increasing order of value xored with xorVal. It returns the largest
// count c, together with the op-aggregate s of those c elements' sum data,
// such that predicate(c, s) holds.
//
// The search assumes predicate is monotone: once it fails for some prefix,
// it fails for every longer one. A non-zero xorVal requires an explicitly
// set log.
func (wm *WaveletMatrixWithSum) MaxRight(predicate func(int32, WmSum) bool, st, end int32, xorVal WmValue) (int32, WmSum) {
	if xorVal != 0 {
		if !wm.setLog {
			panic("log should be set when xor is used")
		}
	}
	// Clamp like every other range query. An end beyond n (or a negative st)
	// used to index presum out of range.
	st, end = max(st, 0), min(end, wm.n)
	if st >= end {
		return 0, wm.e()
	}
	if s := wm._get(wm.log, st, end); predicate(end-st, s) {
		return end - st, s
	}
	// Invariant: taking everything accepted so far plus all of [st, end)
	// makes predicate fail.
	count, sum := int32(0), wm.e()
	for d := wm.log - 1; d >= 0; d-- {
		l0, r0 := wm.bv[d].Rank(st, false), wm.bv[d].Rank(end, false)
		l1, r1 := st+wm.mid[d]-l0, end+wm.mid[d]-r0
		if xorVal>>d&1 == 1 {
			l0, l1 = l1, l0
			r0, r1 = r1, r0
		}
		if s := wm.op(sum, wm._get(d, l0, r0)); predicate(count+r0-l0, s) {
			count += r0 - l0
			sum = s
			st, end = l1, r1
		} else {
			st, end = l0, r0
		}
	}
	// Among equal values, take the longest prefix that still satisfies predicate.
	k := wm._binarySearch(func(k int32) bool {
		return predicate(count+k, wm.op(sum, wm._get(0, st, st+k)))
	}, 0, end-st)
	count += k
	sum = wm.op(sum, wm._get(0, st, st+k))
	return count, sum
}

// _get returns the op-aggregate of the sum data at level-d positions [l, r),
// or e() when no sum data is kept.
func (wm *WaveletMatrixWithSum) _get(d, l, r int32) WmSum {
	if !wm.useSum {
		return wm.e()
	}
	row := wm.presum[d]
	return wm.op(row[r], wm.inv(row[l]))
}

// _argSort returns the indices of nums ordered by ascending value.
// It uses sort.Slice, so ties are not kept in index order.
func (wm *WaveletMatrixWithSum) _argSort(nums []WmValue) []int32 {
	idx := make([]int32, len(nums))
	for i := range idx {
		idx[i] = int32(i)
	}
	sort.Slice(idx, func(x, y int) bool { return nums[idx[x]] < nums[idx[y]] })
	return idx
}

// _maxs returns the largest element of nums. It panics on an empty slice.
func (wm *WaveletMatrixWithSum) _maxs(nums []WmValue) WmValue {
	best := nums[0]
	for _, v := range nums[1:] {
		best = max(best, v)
	}
	return best
}

// _lowerBound returns the index of the first element of the ascending slice
// nums that is >= target, or len(nums) if there is none.
func (wm *WaveletMatrixWithSum) _lowerBound(nums []WmValue, target WmValue) WmValue {
	return WmValue(sort.Search(len(nums), func(i int) bool { return nums[i] >= target }))
}

// _binarySearch narrows (ok, ng) until they are adjacent and returns ok.
// It assumes f(ok) holds and f(ng) does not; the bounds may be given in
// either order.
func (wm *WaveletMatrixWithSum) _binarySearch(f func(int32) bool, ok, ng int32) int32 {
	for ok-ng > 1 || ng-ok > 1 {
		mid := (ok + ng) >> 1
		if f(mid) {
			ok = mid
		} else {
			ng = mid
		}
	}
	return ok
}

// RangeDistinctCount returns the number of distinct values among positions
// [st, end), clamped to [0, n).
func (wm *WaveletMatrixWithSum) RangeDistinctCount(st, end int32) int32 {
	st, end = max(st, 0), min(end, wm.n)
	if end <= st {
		return 0
	}
	// walk counts the non-empty leaves below the node at level d whose
	// elements occupy positions [l, r). Each such leaf is one distinct value.
	var walk func(d, l, r int32) int32
	walk = func(d, l, r int32) int32 {
		switch {
		case l == r:
			return 0
		case d == 0:
			return 1
		}
		d--
		l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)
		zeros := walk(d, l0, r0)
		ones := walk(d, l+wm.mid[d]-l0, r+wm.mid[d]-r0)
		return zeros + ones
	}
	return walk(wm.log, st, end)
}

// RangeDistinctValues returns the distinct values among positions [st, end)
// (clamped to [0, n)) in ascending order. With compress they are mapped back
// to the original values. An empty range yields a non-nil empty slice.
func (wm *WaveletMatrixWithSum) RangeDistinctValues(st, end int32) []WmValue {
	st, end = max(st, 0), min(end, wm.n)
	if end <= st {
		return []WmValue{}
	}
	out := []WmValue{}
	// walk descends the 0-half before the 1-half. prefix accumulates the bits
	// chosen so far, so leaves are reached in ascending value order.
	var walk func(d, l, r int32, prefix WmValue)
	walk = func(d, l, r int32, prefix WmValue) {
		if l == r {
			return
		}
		if d == 0 {
			v := prefix
			if wm.compress {
				v = wm.key[prefix]
			}
			out = append(out, v)
			return
		}
		d--
		l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)
		walk(d, l0, r0, prefix)
		walk(d, l+wm.mid[d]-l0, r+wm.mid[d]-r0, prefix|WmValue(1)<<d)
	}
	walk(wm.log, st, end, 0)
	return out
}

// RangeDistinctCountInRange returns the number of distinct values v with
// a <= v < b among positions [st, end), clamped to [0, n). With compress,
// a and b are original values and are mapped to ranks.
func (wm *WaveletMatrixWithSum) RangeDistinctCountInRange(st, end int32, a, b WmValue) int32 {
	if st < 0 {
		st = 0
	}
	if end > wm.n {
		end = wm.n
	}
	if st >= end || a >= b {
		return 0
	}

	if wm.compress {
		a = wm._lowerBound(wm.key, a)
		b = wm._lowerBound(wm.key, b)
	}

	count := int32(0)
	var dfs func(d, l, r int32, lx, rx WmValue)
	dfs = func(d, l, r int32, lx, rx WmValue) {
		if l == r || rx <= a || b <= lx {
			return
		}
		if d == 0 {
			// A non-empty leaf holds exactly one value, lx. The guard above
			// ensures a <= lx < b.
			count++
			return
		}
		// Keep descending even when [lx, rx) lies wholly inside [a, b). A
		// subtree can hold many distinct values; the old code counted it as one.
		d--
		l0, r0 := wm.bv[d].Rank(l, false), wm.bv[d].Rank(r, false)
		l1, r1 := l+wm.mid[d]-l0, r+wm.mid[d]-r0

		mx := (lx + rx) >> 1
		dfs(d, l0, r0, lx, mx)
		dfs(d, l1, r1, mx, rx)
	}
	dfs(wm.log, st, end, 0, 1<<wm.log)
	return count
}

// RangeXorMax returns the value v among positions [st, end) (clamped to
// [0, n)) that maximizes v ^ xorVal. It returns v itself, not the xor.
// An empty range yields -1. It panics when the matrix was built with
// compress, since ranks do not preserve xor.
func (wm *WaveletMatrixWithSum) RangeXorMax(st, end int32, xorVal WmValue) WmValue {
	// Clamp like the other range queries. An out-of-range end used to reach
	// past the bit vectors.
	st, end = max(st, 0), min(end, wm.n)
	if st >= end {
		return -1
	}
	if wm.compress {
		panic("RangeXorMax is not supported when compress is true")
	}
	res := WmValue(0)
	for d := wm.log - 1; d >= 0; d-- {
		l0, r0 := wm.bv[d].Rank(st, false), wm.bv[d].Rank(end, false)
		zeros := r0 - l0
		// Greedily prefer the bit opposite to xorVal's bit d, which sets bit d
		// of the xor. Fall back when that half is empty.
		var takeOne bool
		if (xorVal>>d)&1 == 0 {
			takeOne = zeros < end-st // some element has bit d == 1
		} else {
			takeOne = zeros == 0 // no 0 bit available, forced to take 1
		}
		if takeOne {
			res |= 1 << d
			st, end = st+wm.mid[d]-l0, end+wm.mid[d]-r0
		} else {
			st, end = l0, r0
		}
	}
	return res
}

// BitVector is a fixed-size bit set. After Build it answers rank queries in
// O(1) using per-word prefix popcounts.
type BitVector struct {
	bits   []uint64 // bit i is stored in bits[i>>6] at position i&63
	preSum []int32  // preSum[w]: set bits in bits[0:w]; valid after Build
}

// NewBitVector returns an all-zero vector for bit positions 0..n. It
// allocates n>>6+1 words, so Rank(n) stays in range.
func NewBitVector(n int32) *BitVector {
	words := n>>6 + 1
	bv := &BitVector{}
	bv.bits = make([]uint64, words)
	bv.preSum = make([]int32, words)
	return bv
}

// Set turns bit i on. Call Build afterwards to refresh rank data.
func (bv *BitVector) Set(i int32) {
	word, bit := i>>6, uint(i&63)
	bv.bits[word] |= uint64(1) << bit
}

// Build recomputes the per-word prefix popcounts.
func (bv *BitVector) Build() {
	for w := 1; w < len(bv.bits); w++ {
		bv.preSum[w] = bv.preSum[w-1] + int32(bits.OnesCount64(bv.bits[w-1]))
	}
}

// Rank returns how many of the bits at positions [0, k) are set when f is
// true, or how many are clear when f is false.
func (bv *BitVector) Rank(k int32, f bool) int32 {
	w := k >> 6
	below := bv.bits[w] & (uint64(1)<<uint(k&63) - 1)
	ones := bv.preSum[w] + int32(bits.OnesCount64(below))
	if !f {
		return k - ones
	}
	return ones
}

// abs32 returns |x|. math.MinInt32 has no positive counterpart and is
// returned unchanged.
func abs32(x int32) int32 {
	if x >= 0 {
		return x
	}
	return -x
}

// min32 returns the smaller of a and b.
func min32(a, b int32) int32 {
	return min(a, b)
}

// io is the process-wide buffered stdin/stdout handle used by solve and main.
var io = NewStdIO()

// MOD is the prime 998244353; it is not referenced in the visible code.
const MOD = 998_244_353

// solve handles one test case. It reads n, m, q, k and then m (key, value)
// pairs. Each of the q queries (l, r, t) counts the pairs whose key lies in
// [l, r] and whose value is >= t, and prints max(0, count-k).
func solve() {
	var n, m, q, k int
	io.Read(&n, &m, &q, &k) // n is not needed: keys are located by binary search
	recs := make([][2]int, m)
	for i := range recs {
		io.Read(&recs[i][0], &recs[i][1])
	}
	// After sorting by key, every key interval [l, r] is a contiguous run of
	// positions found by two binary searches. The wavelet matrix over values
	// in this order then answers the threshold count for that run.
	sort.Slice(recs, func(i, j int) bool { return recs[i][0] < recs[j][0] })
	vals := make([]int, m)
	for i, rec := range recs {
		vals[i] = rec[1]
	}
	// Only counts are queried, so no sum data is built.
	wm := NewWaveletMatrixWithSum(vals, nil, -1, true)
	for ; q > 0; q-- {
		var l, r, t int
		io.Read(&l, &r, &t)
		lo := sort.Search(m, func(i int) bool { return recs[i][0] >= l })
		hi := sort.Search(m, func(i int) bool { return recs[i][0] > r })
		if lo >= hi {
			io.Writeln(0)
			continue
		}
		// Count values >= t as size - #(values < t). This needs no upper
		// bound; the former [t, 1<<32) query silently dropped larger values.
		below, _ := wm.RangeCountAndSum(int32(lo), int32(hi), math.MinInt, t, 0)
		atLeast := hi - lo - int(below)
		io.Writeln(max(0, atLeast-k))
	}
}
// main runs solve for each test case, then flushes buffered output. There is
// a single case; reading a case count is left commented out.
func main() {
	cases := 1
	// io.Read(&cases)
	for done := 0; done < cases; done++ {
		solve()
	}
	io.Close()
}

提出情報

提出日時
問題 E - 図書館の蔵書検索
ユーザ xiaoe
言語 Go (go 1.25.1)
得点 466
コード長 28179 Byte
結果 AC
実行時間 166 ms
メモリ 30464 KiB

ジャッジ結果

セット名 Sample All
得点 / 配点 0 / 0 466 / 466
結果
AC × 5
AC × 94
セット名 テストケース
Sample sample01.txt, sample02.txt, sample03.txt, sample04.txt, sample05.txt
All sample01.txt, sample02.txt, sample03.txt, sample04.txt, sample05.txt, in01.txt, in02.txt, in03.txt, in04.txt, in05.txt, in06.txt, in07.txt, in08.txt, in09.txt, in10.txt, in11.txt, in12.txt, in13.txt, in14.txt, in15.txt, in16.txt, in17.txt, in18.txt, in19.txt, in20.txt, in21.txt, in22.txt, in23.txt, in24.txt, in25.txt, in26.txt, in27.txt, in28.txt, in29.txt, in30.txt, in31.txt, in32.txt, in33.txt, in34.txt, in35.txt, in36.txt, in37.txt, in38.txt, in39.txt, in40.txt, in41.txt, in42.txt, in43.txt, in44.txt, in45.txt, in46.txt, in47.txt, in48.txt, in49.txt, in50.txt, in51.txt, in52.txt, in53.txt, in54.txt, in55.txt, in56.txt, in57.txt, in58.txt, in59.txt, in60.txt, in61.txt, in62.txt, in63.txt, in64.txt, in65.txt, in66.txt, in67.txt, in68.txt, in69.txt, in70.txt, in71.txt, in72.txt, in73.txt, in74.txt, in75.txt, in76.txt, in77.txt, in78.txt, in79.txt, in80.txt, in81.txt, in82.txt, in83.txt, in84.txt, in85.txt, in86.txt, in87.txt, in88.txt, in89.txt
ケース名 結果 実行時間 メモリ
in01.txt AC 1 ms 1664 KiB
in02.txt AC 1 ms 1664 KiB
in03.txt AC 0 ms 1664 KiB
in04.txt AC 0 ms 1664 KiB
in05.txt AC 0 ms 1664 KiB
in06.txt AC 0 ms 1664 KiB
in07.txt AC 0 ms 1664 KiB
in08.txt AC 0 ms 1664 KiB
in09.txt AC 1 ms 1664 KiB
in10.txt AC 159 ms 27008 KiB
in11.txt AC 157 ms 27008 KiB
in12.txt AC 160 ms 29440 KiB
in13.txt AC 61 ms 24704 KiB
in14.txt AC 56 ms 26496 KiB
in15.txt AC 159 ms 27008 KiB
in16.txt AC 75 ms 30464 KiB
in17.txt AC 39 ms 10880 KiB
in18.txt AC 92 ms 30208 KiB
in19.txt AC 1 ms 1664 KiB
in20.txt AC 20 ms 9984 KiB
in21.txt AC 72 ms 29568 KiB
in22.txt AC 134 ms 30208 KiB
in23.txt AC 41 ms 11520 KiB
in24.txt AC 42 ms 10624 KiB
in25.txt AC 89 ms 30464 KiB
in26.txt AC 90 ms 30464 KiB
in27.txt AC 60 ms 11008 KiB
in28.txt AC 166 ms 29440 KiB
in29.txt AC 90 ms 27008 KiB
in30.txt AC 109 ms 27904 KiB
in31.txt AC 1 ms 1664 KiB
in32.txt AC 1 ms 1664 KiB
in33.txt AC 1 ms 1664 KiB
in34.txt AC 1 ms 1920 KiB
in35.txt AC 1 ms 1664 KiB
in36.txt AC 1 ms 1792 KiB
in37.txt AC 1 ms 1664 KiB
in38.txt AC 1 ms 1664 KiB
in39.txt AC 1 ms 1664 KiB
in40.txt AC 1 ms 1792 KiB
in41.txt AC 154 ms 30080 KiB
in42.txt AC 157 ms 29312 KiB
in43.txt AC 1 ms 1664 KiB
in44.txt AC 1 ms 1664 KiB
in45.txt AC 1 ms 1664 KiB
in46.txt AC 1 ms 1664 KiB
in47.txt AC 0 ms 1664 KiB
in48.txt AC 1 ms 1664 KiB
in49.txt AC 1 ms 1664 KiB
in50.txt AC 93 ms 30208 KiB
in51.txt AC 139 ms 30208 KiB
in52.txt AC 132 ms 30208 KiB
in53.txt AC 134 ms 28544 KiB
in54.txt AC 148 ms 29696 KiB
in55.txt AC 148 ms 27904 KiB
in56.txt AC 131 ms 29440 KiB
in57.txt AC 109 ms 30208 KiB
in58.txt AC 162 ms 30080 KiB
in59.txt AC 149 ms 27904 KiB
in60.txt AC 143 ms 30464 KiB
in61.txt AC 154 ms 27008 KiB
in62.txt AC 161 ms 30080 KiB
in63.txt AC 164 ms 29440 KiB
in64.txt AC 157 ms 27008 KiB
in65.txt AC 122 ms 26012 KiB
in66.txt AC 163 ms 29440 KiB
in67.txt AC 111 ms 27392 KiB
in68.txt AC 1 ms 1664 KiB
in69.txt AC 1 ms 1664 KiB
in70.txt AC 1 ms 1664 KiB
in71.txt AC 1 ms 1664 KiB
in72.txt AC 0 ms 1664 KiB
in73.txt AC 0 ms 1664 KiB
in74.txt AC 10 ms 11136 KiB
in75.txt AC 10 ms 11136 KiB
in76.txt AC 1 ms 1664 KiB
in77.txt AC 0 ms 1664 KiB
in78.txt AC 0 ms 1664 KiB
in79.txt AC 1 ms 1664 KiB
in80.txt AC 25 ms 15232 KiB
in81.txt AC 23 ms 14464 KiB
in82.txt AC 1 ms 1664 KiB
in83.txt AC 1 ms 1664 KiB
in84.txt AC 1 ms 1664 KiB
in85.txt AC 1 ms 1664 KiB
in86.txt AC 0 ms 1664 KiB
in87.txt AC 1 ms 1664 KiB
in88.txt AC 115 ms 30464 KiB
in89.txt AC 133 ms 28544 KiB
sample01.txt AC 1 ms 1664 KiB
sample02.txt AC 0 ms 1664 KiB
sample03.txt AC 1 ms 1664 KiB
sample04.txt AC 0 ms 1664 KiB
sample05.txt AC 0 ms 1664 KiB