C - Many Segments Editorial by en_translator (Official)


Since the constraint is as small as \(N \leq 2000\), we can iterate over every pair of integers \((i, j)\) such that \(1 \leq i \lt j \leq N\) in an \(O(N^2)\) loop and, in each iteration, check in \(O(1)\) time whether Interval \(i\) and Interval \(j\) intersect. Hereinafter, we only describe how to determine in \(O(1)\) time whether two intervals intersect.

There are many ways to check whether two intervals intersect. For instance, probably the most straightforward idea is to split into 16 cases depending on \((t_i,t_j)\).

However, such an implementation is tedious to write and prone to bugs. This can be avoided by transforming each interval into a closed one.

Since both endpoints of every interval are integers, the following transformation does not change the answer:

  • \(t_i = 1\): \([l_i,r_i]\) stays \([l_i,r_i]\)
  • \(t_i = 2\): \([l_i,r_i)\) becomes \([l_i,r_i-0.5]\)
  • \(t_i = 3\): \((l_i,r_i]\) becomes \([l_i+0.5,r_i]\)
  • \(t_i = 4\): \((l_i,r_i)\) becomes \([l_i+0.5,r_i-0.5]\)
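The transformation above can be written as a small helper. As a sketch, and as a design choice that is not part of the sample code, the hypothetical helper `to_closed` below doubles every coordinate so that the shifts by 0.5 become integer shifts by 1 and no floating-point values are needed. It assumes the type numbering \(t = 1, 2, 3, 4\) for \([l,r]\), \([l,r)\), \((l,r]\), \((l,r)\) used in the samples.

```python
def to_closed(t, l, r):
    """Hypothetical helper: map an interval of type t (1..4) with integer ends
    l < r to a closed interval [L, R] in doubled coordinates (values times 2)."""
    L, R = 2 * l, 2 * r
    if t in (3, 4):  # open left end: l + 0.5 becomes 2*l + 1
        L += 1
    if t in (2, 4):  # open right end: r - 0.5 becomes 2*r - 1
        R -= 1
    return L, R

print(to_closed(1, 1, 3))  # (2, 6), i.e. [1, 3]
print(to_closed(2, 1, 3))  # (2, 5), i.e. [1, 2.5]
print(to_closed(3, 1, 3))  # (3, 6), i.e. [1.5, 3]
print(to_closed(4, 1, 3))  # (3, 5), i.e. [1.5, 2.5]
```

Because all values stay integral, the comparison \(\max(a,c) \leq \min(b,d)\) can be applied to the doubled values directly.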

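This works because each transformed interval contains exactly the same multiples of 0.5 as the original one, and a nonempty intersection of two intervals with integer endpoints always contains such a point (either a single shared integer, or a segment of length at least 1). By transforming the intervals in this way, one can check whether two intervals intersect without any case analysis, which greatly simplifies the implementation.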

Note that two closed intervals \([a,b]\) and \([c,d]\) intersect if and only if \(\max(a,c) \leq \min(b,d)\).
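To gain some confidence that the transformation together with the \(\max(a,c) \leq \min(b,d)\) test agrees with the original definitions, here is a brute-force cross-check over all small cases. It is a testing aid, not part of the intended solution; the helper names are hypothetical, and it assumes the type numbering \(t = 1, 2, 3, 4\) for \([l,r]\), \([l,r)\), \((l,r]\), \((l,r)\) as in the sample code.

```python
from itertools import product

def contains(t, l, r, x):
    """Membership test for the ORIGINAL interval of type t (1..4)."""
    left_ok = l <= x if t in (1, 2) else l < x    # t = 1, 2 have a closed left end
    right_ok = x <= r if t in (1, 3) else x < r   # t = 1, 3 have a closed right end
    return left_ok and right_ok

def closed(t, l, r):
    """Transformed closed interval: open ends are moved inward by 0.5."""
    a = l + 0.5 if t in (3, 4) else l
    b = r - 0.5 if t in (2, 4) else r
    return a, b

def intersects_closed(a, b, c, d):
    return max(a, c) <= min(b, d)

def intersects_brute(t1, l1, r1, t2, l2, r2, lo=0, hi=5):
    # Scan the points lo, lo + 0.5, ..., hi. This relies on the fact that two
    # intervals with integer endpoints, if they meet at all, share such a point.
    return any(contains(t1, l1, r1, k / 2) and contains(t2, l2, r2, k / 2)
               for k in range(2 * lo, 2 * hi + 1))

ends = [(l, r) for l in range(5) for r in range(l + 1, 5)]  # all 0 <= l < r <= 4
mismatches = 0
for t1, t2 in product(range(1, 5), repeat=2):
    for (l1, r1), (l2, r2) in product(ends, repeat=2):
        a, b = closed(t1, l1, r1)
        c, d = closed(t2, l2, r2)
        if intersects_closed(a, b, c, d) != intersects_brute(t1, l1, r1, t2, l2, r2):
            mismatches += 1
print(mismatches)  # 0
```

All values involved are multiples of 0.5, which floating-point numbers represent exactly, so the float comparisons here are safe.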

Sample code (Python)

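N = int(input())
l = [0]*N
r = [0]*N
for i in range(N):
    t,l[i],r[i] = map(int,input().split())
    # Turn every interval into a closed one by moving open ends inward by 0.5
    if t == 2:    # [l, r)
        r[i] -= 0.5
    elif t == 3:  # (l, r]
        l[i] += 0.5
    elif t == 4:  # (l, r)
        l[i] += 0.5
        r[i] -= 0.5
ans = 0
for i in range(N):
    for j in range(i+1,N):
        # Closed intervals intersect iff max of left ends <= min of right ends
        ans += (max(l[i],l[j]) <= min(r[i],r[j]))
print(ans)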

Bit operations can simplify the code even further: after decrementing \(t\), bit 0 tells whether the right end is open, and bit 1 tells whether the left end is open.

Sample code (C++)

#include<bits/stdc++.h>
using namespace std;

int main(){
    int N; cin >> N;
    vector<double> l(N),r(N);
    for(int i=0; i<N; i++){
        int t; cin >> t >> l[i] >> r[i];
        t--;                  // now t is in 0..3
        if(t&1) r[i] -= 0.5;  // bit 0: right end is open
        if(t&2) l[i] += 0.5;  // bit 1: left end is open
    }
    int ans = 0;
    for(int i=0; i<N; i++){
        for(int j=i+1; j<N; j++){
            ans += (max(l[i],l[j]) <= min(r[i],r[j]));
        }
    }
    cout << ans << endl;
}
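The bit decoding in the C++ sample can be checked against the four interval types with a few lines of Python. This is only an illustration; `open_ends` is a hypothetical helper name.

```python
def open_ends(t):
    """Decode type t (1..4) like the C++ sample: after t -= 1,
    bit 1 marks an open left end and bit 0 marks an open right end."""
    t -= 1
    left_open = bool(t & 2)
    right_open = bool(t & 1)
    return left_open, right_open

for t in range(1, 5):
    left_open, right_open = open_ends(t)
    print(t, ("(" if left_open else "[") + "l, r" + (")" if right_open else "]"))
# 1 [l, r]
# 2 [l, r)
# 3 (l, r]
# 4 (l, r)
```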

BONUS: Solve the problem for \(N \leq 2 \times 10^5\).
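One possible approach to the bonus (a sketch, not an official solution) is to count the complement. After the closed-interval transformation, \([a,b]\) and \([c,d]\) are disjoint exactly when \(b \lt c\) or \(d \lt a\), and at most one of the two can hold. So the number of disjoint pairs equals the number of ordered pairs \((i, j)\) in which interval \(j\) starts strictly after interval \(i\) ends, which can be counted by sorting the left ends and binary searching with Python's `bisect` module. The sketch assumes \(l_i \lt r_i\) (so no interval is counted against itself) and uses doubled coordinates to stay in integers; `count_intersections_fast` is a hypothetical name.

```python
from bisect import bisect_right

def count_intersections_fast(segments):
    """O(N log N) sketch. segments: list of (t, l, r) with t in 1..4 and l < r.
    Returns the number of pairs of intersecting intervals."""
    L, R = [], []
    for t, l, r in segments:
        a, b = 2 * l, 2 * r   # doubled coordinates keep everything integral
        if t in (3, 4):       # open left end
            a += 1
        if t in (2, 4):       # open right end
            b -= 1
        L.append(a)
        R.append(b)
    n = len(segments)
    L_sorted = sorted(L)
    # For each interval, count the intervals whose left end is strictly greater
    # than its right end; each disjoint pair is counted exactly once this way.
    disjoint = sum(n - bisect_right(L_sorted, b) for b in R)
    return n * (n - 1) // 2 - disjoint

print(count_intersections_fast([(1, 1, 2), (2, 2, 3), (3, 2, 4)]))  # 2
```

Sorting dominates, so the whole procedure runs in \(O(N \log N)\) time.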
