
F - K-th Largest Triplet Editorial by en_translator


Sort each of \(A\), \(B\), and \(C\) in descending order. Also, let \(f(i,j,k)=A_iB_j+B_jC_k+C_kA_i\).
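The pruning used below relies on \(f\) being non-increasing in each index once the sequences are sorted. Assuming every element is positive, \(f(i+1,j,k)-f(i,j,k)=(A_{i+1}-A_i)(B_j+C_k)\leq 0\), and the same argument applies to \(j\) and \(k\). The following Python sketch checks this by brute force on small random inputs. The helper names `f` and `is_monotone` are illustrative and not part of the original solution.

```python
import itertools
import random


def f(A, B, C, i, j, k):
    # f(i, j, k) = A_i B_j + B_j C_k + C_k A_i (0-indexed here)
    return A[i] * B[j] + B[j] * C[k] + C[k] * A[i]


def is_monotone(A, B, C):
    """Check that f never increases when any single index is incremented."""
    n = len(A)
    for i, j, k in itertools.product(range(n), repeat=3):
        cur = f(A, B, C, i, j, k)
        if i + 1 < n and f(A, B, C, i + 1, j, k) > cur:
            return False
        if j + 1 < n and f(A, B, C, i, j + 1, k) > cur:
            return False
        if k + 1 < n and f(A, B, C, i, j, k + 1) > cur:
            return False
    return True


# Assumption: all entries are positive, as in the problem's constraints
n = 5
A = sorted((random.randint(1, 100) for _ in range(n)), reverse=True)
B = sorted((random.randint(1, 100) for _ in range(n)), reverse=True)
C = sorted((random.randint(1, 100) for _ in range(n)), reverse=True)
print(is_monotone(A, B, C))  # True for positive, descending-sorted inputs
```

Sorting in ascending order instead breaks the property, e.g. `is_monotone([1, 2], [1, 1], [1, 1])` is `False`.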

We binary search for the answer. It suffices to solve the following decision problem: is the number of triples \((i,j,k)\) with \(f(i,j,k)\geq mid\) at least \(K\)? The answer is the largest \(mid\) for which the decision is "yes".
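To illustrate the reduction, here is a brute-force Python sketch for tiny inputs (the function names are hypothetical). It evaluates \(f\) on all \(N^3\) triples, answers the decision problem by counting, and binary searches for the largest \(mid\) whose count is at least \(K\). This equals the \(K\)-th largest value: at least \(K\) values are \(\geq\) the \(K\)-th largest, but at most \(K-1\) values are strictly larger than it. The sketch assumes non-negative entries and \(K\leq N^3\), so that \(mid=0\) is always a "yes".

```python
import itertools


def all_values(A, B, C):
    # Every value f(i, j, k) over all N^3 triples (brute force, small N only)
    return [a * b + b * c + c * a for a, b, c in itertools.product(A, B, C)]


def count_at_least(values, mid):
    # The decision problem: how many triples satisfy f(i, j, k) >= mid?
    return sum(1 for v in values if v >= mid)


def kth_largest_by_binary_search(A, B, C, K):
    values = all_values(A, B, C)
    # Invariant: count(ok) >= K holds, count(ng) >= K does not
    ok, ng = 0, max(values) + 1
    while ng - ok > 1:
        mid = (ok + ng) // 2
        if count_at_least(values, mid) >= K:
            ok = mid
        else:
            ng = mid
    return ok


A, B, C = [1, 2], [3, 4], [5, 6]
K = 3
direct = sorted(all_values(A, B, C), reverse=True)[K - 1]
print(direct, kth_largest_by_binary_search(A, B, C, K))  # 36 36
```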

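Solving the decision problem naively takes \(O(N^3)\) time, but it can be sped up by pruning. Because every element is positive and each sequence is sorted in descending order, \(f\) is non-increasing in each index:

\[f(i,j,k)\geq f(i+1,j,k),\quad f(i,j,k)\geq f(i,j+1,k),\quad f(i,j,k)\geq f(i,j,k+1).\]

Hence, as soon as a value drops below \(mid\), the current loop can be aborted, since every later value in that loop is no larger. The decision function can be implemented as follows: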

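bool check(long long mid) {
    // Is the number of triples (i, j, k) with f(i, j, k) >= mid at least K?
    // Indices are 1-based; A, B, C are sorted in descending order.
    int cnt = 0;
    for (int i = 1; i <= N; i++) {
        // f(i, 1, 1) >= f(i', j, k) for all i' >= i and all j, k,
        // so if it is below mid, no remaining triple can qualify.
        if (f(i, 1, 1) < mid) break;
        for (int j = 1; j <= N; j++) {
            // Likewise, f(i, j, 1) bounds f(i, j', k) for all j' >= j and all k.
            if (f(i, j, 1) < mid) break;
            for (int k = 1; k <= N; k++) {
                if (f(i, j, k) < mid) break;  // all later k give smaller values
                cnt += 1;
                if (cnt == K) return true;  // K triples found; stop early
            }
        }
    }
    return false;  // fewer than K triples satisfy f(i, j, k) >= mid
}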

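With this pruning, the decision function performs only \(O(K)\) loop iterations: each loop pass that does not break leads to at least one increment of cnt, cnt never exceeds \(K\), and each loop breaks at most once per pass of its enclosing loop. Thus, the problem can be solved in a total of \(O(N\log N+K(\log \max A+\log \max B+\log \max C))\) time.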
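A rough accounting of the \(O(K)\) bound: a passing check on \(i\) always leads to a passing check at the first \(j\), which always leads to at least one increment of cnt. So each of the three levels has at most \(K\) passing checks, and there are at most \(1+K+K\) failing checks (one per pass of each loop), for at most \(5K+1\) evaluations of \(f\). The sketch below is a 0-indexed Python port of check; the name `check_counting` and the evaluation counter are additions for illustration only, and the inputs must already be sorted in descending order.

```python
import random


def check_counting(A, B, C, K, mid):
    """Pruned decision function that also counts evaluations of f.

    Returns (at least K triples have f >= mid, number of f evaluations).
    A, B, C must be sorted in descending order."""
    n = len(A)
    evals = 0

    def f(i, j, k):
        nonlocal evals
        evals += 1
        return A[i] * B[j] + B[j] * C[k] + C[k] * A[i]

    cnt = 0
    for i in range(n):
        if f(i, 0, 0) < mid:
            break
        for j in range(n):
            if f(i, j, 0) < mid:
                break
            for k in range(n):
                if f(i, j, k) < mid:
                    break
                cnt += 1
                if cnt == K:
                    return True, evals
    return False, evals


random.seed(0)
n = 40
A = sorted((random.randint(1, 10**9) for _ in range(n)), reverse=True)
B = sorted((random.randint(1, 10**9) for _ in range(n)), reverse=True)
C = sorted((random.randint(1, 10**9) for _ in range(n)), reverse=True)
K = 1000
found, evals = check_counting(A, B, C, K, mid=10**18)
print(found, evals, evals <= 5 * K + 1)  # the last value is always True
```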

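N, K = map(int, input().split())
# Sort each sequence in descending order so that f is non-increasing in each index
A = sorted(map(int, input().split()), reverse=True)
B = sorted(map(int, input().split()), reverse=True)
C = sorted(map(int, input().split()), reverse=True)


def f(i, j, k):
    # 0-indexed version of f(i, j, k) = A_i B_j + B_j C_k + C_k A_i
    return A[i] * B[j] + B[j] * C[k] + C[k] * A[i]


def calc(mid):
    # Return True if at least K triples satisfy f(i, j, k) >= mid
    cnt = 0
    for i in range(N):
        if f(i, 0, 0) < mid:  # no triple with this or a later i qualifies
            break
        for j in range(N):
            if f(i, j, 0) < mid:  # no triple with this or a later j qualifies
                break
            for k in range(N):
                if f(i, j, k) < mid:  # later k only give smaller values
                    break
                cnt += 1
                if cnt == K:  # stop as soon as K triples are found
                    return True
    return False


# Invariant: calc(ok) is True and calc(ng) is False.
# f never exceeds 3 * 10**18, so calc(3 * 10**18 + 1) is False.
ok, ng = 0, 3 * 10**18 + 1
while ng - ok > 1:
    mid = (ok + ng) // 2
    if calc(mid):
        ok = mid
    else:
        ng = mid

print(ok)  # the largest mid with at least K triples, i.e. the K-th largest value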
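To gain some confidence in the Python solution, one can wrap the same algorithm in a function and compare it against brute force on small random inputs. In this sketch, `solve` and `brute` are hypothetical names, input is passed as arguments instead of being read from standard input, and the random values are kept small so that enumerating all \(N^3\) triples is cheap.

```python
import itertools
import random


def solve(N, K, A, B, C):
    # Same algorithm as the solution above, taking input as arguments
    A = sorted(A, reverse=True)
    B = sorted(B, reverse=True)
    C = sorted(C, reverse=True)

    def f(i, j, k):
        return A[i] * B[j] + B[j] * C[k] + C[k] * A[i]

    def calc(mid):
        cnt = 0
        for i in range(N):
            if f(i, 0, 0) < mid:
                break
            for j in range(N):
                if f(i, j, 0) < mid:
                    break
                for k in range(N):
                    if f(i, j, k) < mid:
                        break
                    cnt += 1
                    if cnt == K:
                        return True
        return False

    ok, ng = 0, 3 * 10**18 + 1
    while ng - ok > 1:
        mid = (ok + ng) // 2
        if calc(mid):
            ok = mid
        else:
            ng = mid
    return ok


def brute(A, B, C, K):
    # Sort all N^3 values in descending order and take the K-th one
    values = sorted(
        (a * b + b * c + c * a for a, b, c in itertools.product(A, B, C)),
        reverse=True,
    )
    return values[K - 1]


random.seed(1)
for _ in range(100):
    N = random.randint(1, 4)
    A = [random.randint(1, 20) for _ in range(N)]
    B = [random.randint(1, 20) for _ in range(N)]
    C = [random.randint(1, 20) for _ in range(N)]
    K = random.randint(1, N**3)
    assert solve(N, K, A, B, C) == brute(A, B, C, K)
print("all tests passed")
```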
