
F - Various Kagamimochi Editorial by en_translator


Let us fix the larger mochi of a pair that forms a kagamimochi. In other words, we pick one mochi and count the pairs in which it is placed at the bottom.

The number of pairs in which a mochi of size \(a\) is at the bottom equals the number of mochi whose size is at most \(\dfrac a2\). (The mochi itself is never counted, since \(a>\dfrac a2\).) Thus, the problem can be solved by counting, for each mochi, the mochi whose sizes are at most half of its size, and summing up these counts.
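As a reference point, the reduction above can be checked with a direct quadratic count. This is a minimal sketch meant only for small inputs; the function name `count_pairs_naive` is hypothetical, and the example list is an input chosen for illustration.

```python
def count_pairs_naive(A):
    """Count pairs (top, bottom) with top size <= bottom size / 2, in O(N^2)."""
    total = 0
    for a in A:              # a is the mochi placed at the bottom
        for b in A:          # b is a candidate for the top
            if 2 * b <= a:   # equivalent to b <= a/2, without fractions
                total += 1
    return total


print(count_pairs_naive([2, 3, 4, 4, 7, 10]))  # 8
```

Here the mochi of sizes 4, 4, 7, 10 are at the bottom of 1, 1, 2, 4 pairs respectively, for a total of 8.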

However, the \(i\)-th mochi can have as many as \(\Theta(i)\) mochi whose sizes are at most half of its size, so inspecting them one by one takes \(\Theta(N^2)\) time in the worst case and does not finish within the execution time limit. (For example, setting \(A _ i=i\) makes the count for the \(i\)-th mochi about \(\dfrac i2\); one can construct cases with even more such pairs.)
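For the example \(A _ i=i\), the count can be made exact: the \(i\)-th mochi is at the bottom of exactly \(\left\lfloor \dfrac i2 \right\rfloor\) pairs (those with sizes \(1,\dots,\left\lfloor \dfrac i2 \right\rfloor\)), so the total is already quadratic in \(N\):

\[
\sum_{i=1}^{N} \left\lfloor \frac{i}{2} \right\rfloor = \left\lfloor \frac{N^2}{4} \right\rfloor = \Theta(N^2).
\]

For even \(N=2m\) the sum is \(m^2\), and for odd \(N=2m+1\) it is \(m^2+m\), both of which equal \(\left\lfloor \dfrac{N^2}{4} \right\rfloor\).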

Here, the problem can be solved fast enough with the sliding window technique or binary search. More specifically, since the mochi are sorted in ascending order, the mochi with sizes at most \(x\) form a prefix of the sequence, so their count equals the position of the partition point between the sizes at most \(x\) and those greater than \(x\).
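The partition point can also be found without a library call. The following is a minimal sketch of a hand-written binary search that returns the number of elements at most `x` in an ascending list; the name `count_at_most` is hypothetical, and it is intended to return the same value as Python's `bisect.bisect_right(A, x)`.

```python
def count_at_most(A, x):
    """Return the number of elements <= x in the ascending list A.

    Invariant: every element before index lo is <= x,
    and every element at index hi or later is > x.
    """
    lo, hi = 0, len(A)
    while lo < hi:
        mid = (lo + hi) // 2
        if A[mid] <= x:
            lo = mid + 1  # A[mid] is on the "<= x" side of the partition
        else:
            hi = mid      # A[mid] is on the "> x" side of the partition
    return lo             # lo == hi is the partition point


A = [2, 3, 4, 4, 7, 10]
print(count_at_most(A, 10 // 2))  # 4 (the mochi of sizes 2, 3, 4, 4)
```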

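The sliding window technique yields a total time complexity of \(\Theta(N)\): because the mochi are processed in ascending order of size, the partition point for \(\dfrac a2\) never moves backward, so the pointer advances at most \(N\) times in total. Binary search yields \(\Theta(N\log N)\).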
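The \(\Theta(N)\) bound is an amortized argument: the inner loop may run many times for a single mochi, but the pointer only ever increases. The sketch below (the function name `sliding_window_with_steps` is hypothetical) counts the pairs in the same way as the sample code and additionally reports how many times the pointer advanced.

```python
def sliding_window_with_steps(A):
    """Count kagamimochi pairs in an ascending list A with the sliding window.

    Returns (answer, steps), where steps is the total number of pointer advances.
    """
    n = len(A)
    ans = 0
    steps = 0
    j = 0  # index of the first element greater than a/2 (or n if there is none)
    for a in A:
        # a is non-decreasing, so the partition point for a/2 only moves right
        while j < n and A[j] * 2 <= a:
            j += 1
            steps += 1
        ans += j
    return ans, steps


print(sliding_window_with_steps([2, 3, 4, 4, 7, 10]))  # (8, 4)
```

Since each advance increases `j` by one and `j` never exceeds `n`, `steps` equals the final value of `j`, which is at most \(N\).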

Note that the answer may not fit in a \(32\)-bit integer type, so use a \(64\)-bit type (such as long long in C++).
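As a rough illustration, the sketch below uses the family \(A _ i=i\) with a hypothetical \(N=10^5\) (an assumed value chosen only for illustration, not a constraint taken from the problem) and emulates two's-complement \(32\)-bit wraparound in Python; the helper name `to_int32` is hypothetical.

```python
def to_int32(x):
    """Emulate storing x in a two's-complement 32-bit signed integer."""
    return (x + 2**31) % 2**32 - 2**31


N = 10**5  # hypothetical size, chosen only for illustration
# With A_i = i, the i-th mochi is at the bottom of exactly i // 2 pairs
answer = sum(i // 2 for i in range(1, N + 1))

print(answer)              # 2500000000
print(answer > 2**31 - 1)  # True
print(to_int32(answer))    # -1794967296
```

The true count already exceeds \(2^{31}-1=2147483647\), and a \(32\)-bit accumulator would silently produce a negative value.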

The following is sample code.

(C++, binary search)

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int main() {
    int N;
    cin >> N;

    vector<int> A(N);
    for (auto&& a : A)
        cin >> a;

    long long ans = 0;  // the answer can exceed the 32-bit range
    for (const auto a : A)
        // Number of mochi with size at most (a/2) = distance from the beginning of A to the first mochi whose size exceeds (a/2)
        ans += ranges::upper_bound(A, a / 2) - begin(A);

    // The sum is the answer
    cout << ans << endl;

    return 0;
}

(Python, binary search)

from bisect import bisect


N = int(input())
A = list(map(int, input().split()))

ans = 0

for a in A:
    # bisect (= bisect_right) returns the number of elements of A that are at most a // 2
    ans += bisect(A, a // 2)

print(ans)

(C++, sliding window)

#include <iostream>
#include <vector>

using namespace std;

int main() {
    int N;
    cin >> N;

    vector<int> A(N);
    for (auto&& a : A)
        cin >> a;

    long long ans = 0;  // the answer can exceed the 32-bit range
    // j is the index of the first element greater than (a/2) (or N if there is none)
    for (int j = 0; const auto a : A) {
        // Advance j while A[j] <= a/2
        while (j < N && A[j] * 2 <= a) j++;
        ans += j;
    }

    cout << ans << endl;

    return 0;
}

(Python, sliding window)

N = int(input())
A = list(map(int, input().split()))

ans = 0

# j is the index of the first element greater than a/2 (or N if there is none)
j = 0

for a in A:
    # Advance j while A[j] <= a/2
    while j < N and A[j] * 2 <= a:
        j += 1
    ans += j

print(ans)
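When modifying either solution, it can help to cross-check it against a direct quadratic count on small random inputs. The following is a minimal sketch; the function names are hypothetical, and the random sizes are kept small (1 to 20) so that equal sizes and exact halves occur often.

```python
import random
from bisect import bisect


def solve_naive(A):
    # Directly count pairs where the top mochi is at most half of the bottom one
    return sum(1 for a in A for b in A if 2 * b <= a)


def solve_binary_search(A):
    # bisect (= bisect_right) counts elements <= a // 2 in the sorted list
    return sum(bisect(A, a // 2) for a in A)


def solve_sliding_window(A):
    ans = 0
    j = 0  # index of the first element greater than a/2 (or len(A) if none)
    for a in A:
        while j < len(A) and A[j] * 2 <= a:
            j += 1
        ans += j
    return ans


random.seed(0)
for _ in range(1000):
    A = sorted(random.randint(1, 20) for _ in range(random.randint(1, 30)))
    expected = solve_naive(A)
    assert solve_binary_search(A) == expected
    assert solve_sliding_window(A) == expected
print("all tests passed")
```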
