G - Ax + By < C 解説
by
sounansya
\(N\) 本の直線 \(A_ix+B_iy=C_i\) \((1\le i\le N)\) を考え、「どの直線よりも下にある \(x,y\) 座標が共に正である格子点の個数」と言い換えます。また、簡単のため同じ傾きの直線は最も下にあるもののみ存在するとし、同じ傾きの直線が存在しない状況を考えます。
ここで、ある \(1\) 本の直線に注目すると、その直線が \(N\) 本の中で最も下にくるような \(x\) 座標は区間をなすことが分かります。それぞれの直線がどの区間で最小となるかは凸包の双対問題となり、凸包を求めるアルゴリズムと同様に次のようなステップで求めることができます(参考:Mongeの手引書 (tatyamさん)):
- 直線を傾き \(\displaystyle \left(=\frac{A_i}{B_i}\right)\) の小さい順にソートする。
- 直線を管理するための空のリスト \(L\) を用意する。
- \(i=1,2,\ldots,N\) の順に以下の操作を行う:
- \(i\) 番目に傾きの小さい直線を \(L\) に追加する。
- 「\(1\) つ前の直線が不要なら削除する」という操作を不要な直線がある間繰り返す。
- 最終的に \(L\) に残った直線が最小値になりうる直線である。
この \(L\) に残った直線の個数を \(N\) と置き直し、 \(L\) の \(i\) 番目の直線を \(A_ix+B_iy=C_i\) と置き直します。
また、 \(i\) 番目の直線と \(i+1\) 番目の直線の交点の \(x\) 座標の小数部分を切り上げた値を \(x_i\) \((1\le i\le N-1)\) とし、 \(x_0=-\infty\) 、 \(x_N=\infty\) とします。
この時、 \(\displaystyle X_{\mathrm{max}}=\min_{1\le i\le N} \left\lceil \frac{C_i}{A_i} \right\rceil\) とすると、それぞれの直線の貢献度を考えることによりこの問題の答えは
\[ \displaystyle \sum_{i=1}^N \sum_{x\in [x_{i-1}, x_i)\cap [1, X_{\mathrm{max}})}\left\lfloor \frac{C_i-1-A_ix}{B_i} \right\rfloor \]
になることが分かります。
\(\displaystyle \sum_{x\in [x_{i-1}, x_i)\cap [1, X_{\mathrm{max}})}\left\lfloor \frac{C_i-1-A_ix}{B_i} \right\rfloor\) は floor sum アルゴリズムを用いることで \(O(\log A_i)\) で計算することができます。よって、この問題の答えは各テストケース \(O(N(\log N+\log \max_i A_i))\) で計算することができます。
from functools import cmp_to_key
import sys
input = sys.stdin.readline
def floor_sum(n, m, a, b):
a1, a2 = a // m, a % m
ans = n * (n - 1) // 2 * a1
if a2 == 0:
return ans + b // m * n
b1, b2 = b // m, b % m
y = (a2 * n + b2) // m
ans += n * (y + b1) - floor_sum(y, a2, m, m + a2 - b2 - 1)
return ans
def cmp(d1, d2):
a1, b1, c1 = d1
a2, b2, c2 = d2
if a1 * b2 != a2 * b1:
return a1 * b2 - a2 * b1
return c1 * a2 - c2 * a1
def clamp(x, l, r):
return min(r, max(l, x))
for _ in range(int(input())):
N = int(input())
L = [tuple(map(int, input().split())) for _ in range(N)]
L.sort(key=cmp_to_key(cmp))
co = []
x = []
INF = 10**18
x_lim = INF
for l in L:
aj, bj, cj = l
if len(co) >= 1:
ai, bi, ci = co[-1]
if ai * bj == aj * bi:
continue
while len(co) >= 2:
ai, bi, ci = co[-1]
xj = (bi * cj - bj * ci - 1) // (aj * bi - ai * bj) + 1
if xj > x[-1]:
break
co.pop()
x.pop()
co.append(l)
if len(x) == 0:
x.append(-INF)
else:
ai, bi, ci = co[-2]
xj = (bi * cj - bj * ci - 1) // (aj * bi - ai * bj) + 1
x.append(xj)
x_lim = min(x_lim, (cj + aj - 1) // aj)
x.append(INF)
ans = 0
for i in range(len(co)):
ai, bi, ci = co[i]
le, ri = clamp(x[i], 1, x_lim), clamp(x[i + 1], 1, x_lim)
n, m, a, b = ri - le, bi, ai, ci - 1 - ai * (ri - 1)
ans += floor_sum(n, m, a, b % m) + (b // m) * n
print(ans)
投稿日時:
最終更新: