D - AtCoder Express Editorial by sounansya


\(\displaystyle M=\sum_{i=1}^N t_i\) と定義します。

\(d[i]=\{\) 時刻 \(i+0.5\) に出すことが許されているスピード \(\}\) と定義します。これは \(t,v\) から簡単に求めることができます。

\(a[i]=\{i \) 秒後に出すことができるスピードの最大値 \(\}\) と定義します。これは次のようにすることで求めることができます。

  • \(a\) の要素を 十分大きい値で満たす。
  • \(i=1,2,\ldots,N\) に対し、 \(\displaystyle \sum_{j=1}^{i-1}t_j\le k\le \sum_{j=1}^it_j\) を満たす全ての \(k\) に対し \(a[k]\)\(\min(a[k],v[i])\) で置き換える
  • \(a[0]=a[M]=0\) とする
  • \(i=1,2,\ldots,M\) に対し、 \(a[i]\)\(\min(a[i] , a[i-1]+1)\) で置き換える
  • \(i=M-1,M-2,\ldots,0\) に対し、 \(a[i]\)\(\min(a[i] , a[i+1]+1)\) で置き換える

\(1,2\) つ目で加速度の制限がない場合のスピードの最大値を求め、 \(3\) つ目で走行開始時と走行終了時には列車は止まっていなければならないという条件を追加する。そして最後に \(3,4\) つ目で加速度の制限を課す、と考えると分かりやすいです。

次に、列車が発車してから停車するまでに走れる最大の距離を求めます。

この値は \(i=0,1,\ldots,M-1\) に対して時刻 \(i\) から時刻 \(i+1\) の間で進める距離の最大値の和と言い換えることができます。これを各 \(i\) に対して求めます。

まず、 \(a[i]\neq a[i+1]\) の場合は \(\displaystyle\frac{a[i]+a[i+1]}2\) が答えです。これはもし \(a[i] < a[i+1]\) なら加速度 \(+1\) での移動が、 \(a[i] > a[i+1]\) なら加速度 \(-1\) での移動が最適だからです。

次に、 \(a[i]=a[i+1]\) である場合を考えます。

もし \(a[i]=d[i]\) なら、 \(a[i]\) が答えです。これは現在出せる最大のスピードを出しており、これ以上スピードを上げることができないからです。

次に \(a[i]\neq d[i]\) の場合ですが、これは \(\displaystyle a[i]+0.25\) が答えです。これは時刻 \(i\) から時刻 \(\displaystyle i+0.5\) まで加速度 \(+1\) で移動し、時刻 \(\displaystyle i+0.5\) から時刻 \(i+1\) まで加速度 \(-1\) で移動することで達成することができます。この移動が最適なのは明らかです。

以上をまとめることで、この問題を \(O(M)\) で解くことができます。

\(\displaystyle M=\sum_{i=1}^N t_i\le \sum_{i=1}^{100}200=2\times10^4\) なので、このアルゴリズムは十分高速に動作します。

なお、下の実装では出来る限り整数型で処理するために各移動の最大値を \(4\) 倍し、最後に \(4\) で割っています。

実装例 (Java)

import java.util.*;
class Main {
    public static void main(String[] args) {
      	var sc = new Scanner(System.in);
        int n = sc.nextInt();
        var t = new int[n];
        Arrays.setAll(t, i -> sc.nextInt());
        var v = new int[n];
        Arrays.setAll(v, i -> sc.nextInt());
        int m = 0;
        for (int i : t)
            m += i;
        var a = new int[m + 1];
        var d = new int[m];
        final int inf = (int) 1e9;
        Arrays.fill(a, inf);
        int c = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= t[i]; j++)
                a[c + j] = Math.min(a[c + j], v[i]);
            for (int j = 0; j < t[i]; j++)
                d[c + j] = v[i];
            c += t[i];
        }
        a[0] = a[m] = 0;
        for (int i = 1; i <= m; i++)
            a[i] = Math.min(a[i], a[i - 1] + 1);
        for (int i = m - 1; i >= 0; i--)
            a[i] = Math.min(a[i], a[i + 1] + 1);
        long ans = 0;
        for (int i = 0; i < m; i++) {
            if (a[i] != a[i + 1])
                ans += 2 * (a[i] + a[i + 1]);
            else if (a[i] == d[i])
                ans += 4 * a[i];
            else
                ans += 4 * a[i] + 1;
        }
        System.out.println(ans / 4.);
    }
}

posted:
last update: