公式

D - Sum of Differences 解説 by sheyasutaka


考察

\(A\) の要素を昇順に並べ替えて \(A_1 \leq A_2 \leq \dots \leq A_N\) としても答えは変化しません。以下では \(A\) は昇順であるとします。

\(|x - y|\) の値は、\(x \geq y\) のときは \(x-y\) であり、\(x < y\) のときは \(y-x\) です。この場合分けから、\(A_i \geq B_j\) の場合と \(A_i < B_j\) の場合のそれぞれについて分けて考える方針が立ちます。

ある \(B_j\) を固定して \(y\) とおきます。\(\sum_{i = 1}^{N} |A_i - y|\) の値は、以下の \(2\) つの値の和として表せます。

  • \(A_i \geq y\) となる \(A_i\) に対する \((A_i - y)\) の和。
  • \(A_i < y\) となる \(A_i\) に対する \((y - A_i)\) の和。

前者の値の求め方

ここで、\(A_i \geq y\) となる最小の \(i\) (存在しなければ \(N+1\)) を \(i_y\) とおくと、 \(A_i \geq y\) となる \(A_i\) の個数は \(N+1-i_y\) となることから、前者の値は以下のように式変形できます。

\[\sum_{A_i \geq y} (A_i - y) = \sum_{A_i \geq y} A_i - \sum_{A_i \geq y} y = (A_{i_y+1} + \dots + A_N) - (N + 1 - i_y)y\]

したがって、\(y\) に対する \(i_y\) の値と、\((A_{i_y+1} + \dots + A_N)\) の値をそれぞれ高速に求めれば、上の値を高速に計算できます。

\(A_i \geq y\) となる最小の \(i\) は、\(A\) が昇順に並んでいるので二分探索で求めることができます。

\((A_{i_y + 1} + \dots + A_N)\) の値は、それぞれの \(i_y \in \{1, \dots, N+1\}\) について降順に求めておく前計算を事前にすることで、\(y\) ごとに毎回独立に計算する必要がなくなります。

以上の高速化によって、求める値を高速に計算できます。

後者の値の求め方

\(A_i \geq y\) となる最小の \(i\) (存在しなければ \(N+1\)) を \(i_y\) とおくと、 \(A_i < y\) となる \(A_i\) の個数は \(i_y-1\) となることから、後者の値は以下のように式変形できます。

\[\sum_{A_i < y} (y - A_i) = \sum_{A_i < y} y - \sum_{A_i < y} A_i = (i_y - 1)y - (A_1 + \dots + A_{i_y - 1})\]

これは前者の値でおこなったやり方と同様にして高速に求まります。


全体の実装方針

ここまでの考察から、以下の実装を行えば問題を解くことができます。

  • \(A\) を昇順に並べ替える。
  • \(i \in \{0, 1, \dots, N+1\}\) に対して、\((A_1 + \dots + A_i)\) を累積和によって計算する。
  • \(i \in \{0, 1, \dots, N+1\}\) に対して、\((A_i + \dots + A_N)\) を累積和によって計算する。
  • 変数 \(ans\)\(0\) で初期化する。
  • \(j = 1, \dots, M\) に対して、\(y = B_j\) として以下の計算を行う。
    • \(A_i \geq y\) となる最小の \(i\) (存在しなければ \(N+1\)) を二分探索によって求める(その値を \(i_y\) とする)。
    • \((A_{i_y+1} + \dots + A_N) - (N + 1 - i_y)y\)\(ans\) に加える。
    • \((i_y - 1)y - (A_1 + \dots + A_{i_y - 1})y\)\(ans\) に加える。
  • \(ans \bmod 998244353\) の値を出力する。

時間計算量は、ソートおよび二分探索がボトルネックになって \(O((N+M) \log (N+M))\) となり、十分高速です。

適宜 mod をとる必要があることに注意して下さい。特に 32 bit 整数や 64 bit 整数などの固定ビット長変数を使う場合は、オーバーフローを避けるため、最後の出力直前以外でも適切に mod をとる必要があることに注意してください。

実装例 (C++)

#include <iostream>
using std::cin;
using std::cout;
#include <vector>
using std::vector;
#include <algorithm>
using std::sort;

typedef long long int ll;

const ll MOD = 998244353;

ll n, m;
vector<ll> a, b;


void solve () {
	sort(a.begin(), a.end());
	sort(b.begin(), b.end());

	// sum[0, i), sum[i, n)
	vector<ll> apre(n+1, 0LL), asuf(n+1, 0LL);
	apre[0] = 0;
	for (ll i = 0; i < n; i++) {
		apre[i+1] = (apre[i] + a[i]);
	}
	asuf[n] = 0;
	for (ll i = n-1; i >= 0; i--) {
		asuf[i] = (asuf[i+1] + a[i]);
	}

	ll ans = 0;
	for (ll i = 0; i < m; i++) {
		// a[<x] < b[i], a[>=x] >= b[i]
		ll ok = n, ng = -1;
		while (ng + 1 < ok) {
			ll med = (ok + ng) / 2;
			if (a[med] >= b[i]) {
				ok = med;
			} else {
				ng = med;
			}
		}
		const ll x = ok;

		// prefix sum
		ans += ((b[i] * x) - apre[x]) % MOD;
		// suffix sum
		ans += (asuf[x] - (b[i] * (n-x))) % MOD;
	}
	ans %= MOD;

	cout << ans << "\n";
	return;	
}

int main (void) {
	cin >> n >> m;
	a.resize(n);
	b.resize(m);
	for (ll i = 0; i < n; i++) {
		cin >> a[i];
	}
	for (ll i = 0; i < m; i++) {
		cin >> b[i];
	}

	solve();

	
	return 0;
}

投稿日時:
最終更新: