Official

D - Freefall Editorial by en_translator


If Takahashi performs the operation \(n\ (n \geq 0)\) times, he lands at time \(f(n) = Bn+\frac{A}{\sqrt{n+1}}\). Thus, we need to find the minimum value of \(f(n)\) when \(n\) spans the integers greater than or equal to \(0\).

Hereinafter, we regard \(f\) as a function defined for all real numbers greater than or equal to \(0\).

Let us first bound the range of the answer. If \(n \geq A/B\), we have \(f(n) > Bn \geq A = f(0)\), so \(f(n) > f(0)\); thus, we only have to consider the range \(0 \leq n < A/B\). However, \(A\) and \(B\) can be so large in this problem that we cannot evaluate \(f(n)\) for every candidate \(n\).
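For small inputs, the bound above already gives a usable reference solution, which can help when checking a faster implementation. The following is a minimal sketch under that assumption; the names `landing_time` and `brute_force_min` are hypothetical helpers introduced here, and the input range \(A, B \geq 1\) is assumed.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// f(n) = B*n + A/sqrt(n+1), as defined in the editorial.
double landing_time(long long a, long long b, long long n) {
    return static_cast<double>(b) * static_cast<double>(n)
         + static_cast<double>(a) / std::sqrt(static_cast<double>(n) + 1.0);
}

// Hypothetical helper: exhaustive scan over 0 <= n <= floor(A/B).
// Only practical when A/B is small (say, up to a few million).
double brute_force_min(long long a, long long b) {
    assert(a >= 1 && b >= 1);  // assumed input range
    double best = landing_time(a, b, 0);
    for (long long n = 1; n <= a / b; ++n) {
        best = std::min(best, landing_time(a, b, n));
    }
    return best;
}
```

For the illustrative values \(A=10,\ B=1\), the scan finds the minimum at \(n=2\), where \(f(2) = 2 + 10/\sqrt{3} \approx 7.7735\).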

In order to solve this problem, you need to realize that \(f\) is a convex function. The figure below plots \(y=f(x)\) for Sample Input/Output \(1\) (with \(x\) on the horizontal axis). A function of this shape is called a convex function. (For a technical definition, see https://en.wikipedia.org/wiki/Convex_function.) The convexity of \(f(x)\) can be proved using the theorem that “the sum of convex functions is convex,” or by checking the sign of the second derivative.

Figure: \(f(n)\) for Sample Input 1
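One way to carry out the derivative-based proof is the short computation below, offered as a sketch; it assumes \(A > 0\) and \(B > 0\):

\[
f'(x) = B - \frac{A}{2}(x+1)^{-3/2}, \qquad f''(x) = \frac{3A}{4}(x+1)^{-5/2} > 0 \quad (x \geq 0).
\]

Since \(f''\) is positive on the whole domain, \(f\) is convex there (in fact strictly convex), so any local minimum is also the global minimum.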

Now we introduce two solutions to this problem.

Solution 1: ternary search

One can find the minimum value of a convex function with an algorithm called ternary search. In ternary search, we maintain \(l\) and \(r\) satisfying the invariant (*): “\(f(n)\) attains its minimum at some \(n\) such that \(l\leq n\leq r\).” We then narrow the range \([l,r]\) while preserving (*) in order to find the desired \(n\). Specifically, the algorithm goes as follows:

  1. Set initial \(l\) and \(r\) that satisfy (*).
  2. Let \(m_1 := (2l+r)/3\) and \(m_2 := (l+2r)/3\). If \(f(m_1) < f(m_2)\), let \(r \leftarrow m_2\); otherwise, let \(l \leftarrow m_1\). (By convexity, the discarded part cannot contain the minimum, so (*) is preserved.)
  3. If the range \([l,r]\), which is guaranteed to contain the answer, is narrow enough, terminate the procedure. Otherwise, repeat step 2.

This time, we want to find the minimum value for integers \(n\), so we let \(l,r,m_1\) and \(m_2\) be integers (if we let them be real values, we can find the minimum value in the real value domain).
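The real-valued variant mentioned in parentheses can be sketched as follows. This is only an illustration, not the editorial's own code: the function name `ternary_search_real` and the default of 100 iterations are choices made here, and both parameters are assumed to be positive.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: ternary search over real x in [0, A/B].
// Returns an approximate real minimizer of f(x) = B*x + A/sqrt(x+1).
double ternary_search_real(double a, double b, int iterations = 100) {
    assert(a > 0.0 && b > 0.0);  // assumed: both parameters are positive
    auto f = [&](double x) { return b * x + a / std::sqrt(x + 1.0); };
    double l = 0.0, r = a / b;
    for (int i = 0; i < iterations; ++i) {
        double m1 = (2.0 * l + r) / 3.0;
        double m2 = (l + 2.0 * r) / 3.0;
        if (f(m1) < f(m2)) r = m2;  // minimum cannot lie in (m2, r]
        else l = m1;                // minimum cannot lie in [l, m1)
    }
    return (l + r) / 2.0;
}
```

With the illustrative values `a = 10.0, b = 1.0`, the returned point is close to \(1.924\). Because \(f\) is very flat near its minimum, the minimizer found in floating point is typically accurate only to several digits, even though the minimum value itself is accurate to many more.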

In ternary search, each repetition of step 2 shrinks the range \([l,r]\) containing the answer to about \(\frac{2}{3}\) of its previous length, so we can find the \(n\) that minimizes \(f(n)\) fast enough. In fact, in this problem we can use the initial values \(l=0,\ r=\lfloor A/B \rfloor\), and repeating step 2 about 100 times is sufficient.
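As a rough check of the iteration count, suppose each step shrinks the range by a factor of about \(\frac{2}{3}\), and assume for illustration that the initial range has length \(10^{18}\) (this figure is an assumption made here, not a value stated in the text). The number of steps \(k\) needed to bring the length down to about \(1\) then satisfies

\[
\left(\frac{2}{3}\right)^{k} \cdot 10^{18} \leq 1
\iff k \geq \frac{18 \ln 10}{\ln (3/2)} \approx \frac{41.45}{0.4055} \approx 102.2,
\]

which is consistent with the estimate of about 100 repetitions.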

Sample code (C++):

#include<bits/stdc++.h>

using namespace std;
using ll = long long;

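int main() {
    ll a, b;
    cin >> a >> b;
    // f(n): landing time after performing the operation n times
    auto f = [&](ll n) -> double {
        return (double) a / sqrt(n + 1) + (double) b * n;
    };
    // The minimizer is known to lie in [0, a / b]
    ll l = 0, r = a / b;
    // Ternary search on integers: shrink [l, r] until at most 3 candidates remain
    while (r - l > 2) {
        ll m1 = (l * 2 + r) / 3;
        ll m2 = (l + r * 2) / 3;
        if (f(m1) > f(m2)) l = m1;
        else r = m2;
    }
    // Evaluate the few remaining candidates directly
    double ans = a;
    for (ll i = l; i <= r; i++) {
        ans = min(ans, f(i));
    }
    cout << fixed << setprecision(10) << ans << endl;
}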

Solution 2: derivative

Here, we assume knowledge of high-school math.

Since \(f'(x) = B-\frac{A}{2(x+1)\sqrt{x+1}}\) is increasing in \(x\), we have \(f'(x)=0 \Leftrightarrow x = (\frac{A}{2B})^{\frac{2}{3}}-1\), so over real \(x\) the function \(f(x)\) is minimized at \(x_0 = (\frac{A}{2B})^{\frac{2}{3}}-1\) (or at \(x=0\) if \(x_0\) is negative). Although we need the minimum over integers \(x\) in this problem, by convexity we only have to inspect \(\lfloor x_0 \rfloor\) and \(\lceil x_0 \rceil\), restricted to non-negative values. To guard against numerical errors, it is a good idea to inspect every integer within \(\pm 5\) of this value.
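The root of \(f'\) can be obtained step by step as below. The numbers \(A=10,\ B=1\) in the second line are illustrative values chosen here, not data quoted from the problem statement.

\[
f'(x)=0 \iff (x+1)^{3/2} = \frac{A}{2B} \iff x = \left(\frac{A}{2B}\right)^{2/3} - 1
\]
\[
A=10,\ B=1:\quad x = 5^{2/3}-1 \approx 1.924,\qquad f(1) = 1 + \frac{10}{\sqrt{2}} \approx 8.0711,\quad f(2) = 2 + \frac{10}{\sqrt{3}} \approx 7.7735
\]

In this illustrative case the two neighboring integers are \(1\) and \(2\), and \(n=2\) gives the smaller value.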

Sample code (C++):

#include<bits/stdc++.h>

using namespace std;
using ll = long long;

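int main() {
    ll a, b;
    cin >> a >> b;
    // f(n): landing time after performing the operation n times
    auto f = [&](ll n) -> double {
        return (double) a / sqrt(n + 1) + (double) b * n;
    };
    // Real-valued minimizer from f'(x) = 0, truncated toward zero
    ll argmin = pow((double) a / (b * 2), 2. / 3.) - 1;
    // Scan a window of +-5 around it, clamped to [0, a / b], to absorb rounding error
    ll l = max(argmin - 5, 0LL), r = min(argmin + 5, a / b);
    double ans = a;
    for (ll i = l; i <= r; i++) {
        ans = min(ans, f(i));
    }
    cout << fixed << setprecision(10) << ans << endl;
}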
