公式

D - 都市間の距離 / Distance Between Cities 解説 by admin

Qwen3-Coder-480B

概要

\(N\) 個の \(M\) 次元座標を持つ都市間のマンハッタン距離の総和を効率的に求める問題です。

考察

マンハッタン距離は各次元ごとの差の絶対値の和で表されます:

\[ \sum_{1 \leq i < j \leq N} \sum_{k=1}^{M} |A_{i,k} - A_{j,k}| \]

素朴な解法では、全ての都市の組 \((i, j)\) について、各次元での差分の絶対値を計算する必要があります。これは \(O(N^2 M)\) の計算量となり、\(N\) が最大 \(2 \times 10^5\) の制約では到底間に合いません。

しかし、マンハッタン距離の性質を利用することで、計算を高速化できます。各次元 \(k\) において、\(|A_{i,k} - A_{j,k}|\) の総和を独立に求めることができます。つまり、各次元ごとに処理すれば良いのです。

さらに、各次元について座標をソートすることで、各座標値が他のどの座標値と比較されるかを効率的に求められます。具体的には、ある値 \(x\) がソートされた配列の \(i\) 番目にあったとき:

  • \(x\) が「大きい側」として扱われるのは、自分より左にある要素(\(i\) 個)との比較において
  • \(x\) が「小さい側」として扱われるのは、自分より右にある要素(\(N - 1 - i\) 個)との比較において

したがって、その値 \(x\) の寄与は:

\[ x \times (i - (N - 1 - i)) = x \times (2i - N + 1) \]

と計算できます。これにより、各次元について \(O(N \log N)\) で計算可能になります。

アルゴリズム

  1. 各次元 \(k = 1, \dots, M\) について以下を行う:

    • すべての都市の \(k\) 次元目の座標を抜き出してリスト col を作る
    • col をソートする
    • ソート後の各要素 col[i] について、それが他の要素と比較される際に与える寄与を計算:
      • 左側(自分より小さい値)との比較回数:\(i\)
      • 右側(自分より大きい値)との比較回数:\(N - 1 - i\)
      • 寄与:col[i] * (i - (N - 1 - i))
    • この寄与を合計に加算する
  2. 各次元の寄与の合計が答えとなる。

計算量

  • 時間計算量: \(O(NM \log N)\)
    (各次元ごとにソートが必要 → \(M\) 回のソート、それぞれ \(O(N \log N)\)
  • 空間計算量: \(O(NM)\)
    (すべての座標を保持)

実装のポイント

  • 座標の読み込みを高速に行うために sys.stdin.read を使用している

  • 各次元ごとに独立に計算できることに注目し、多次元を一次元ずつ処理する

  • ソート後に「各要素がどれだけ寄与するか」を係数で計算するのがポイント

    ソースコード

import sys
input = sys.stdin.read

def main():
    data = input().split()
    N = int(data[0])
    M = int(data[1])
    
    # 座標データの読み込み
    coords = []
    idx = 2
    for i in range(N):
        row = list(map(int, data[idx:idx+M]))
        coords.append(row)
        idx += M
    
    total = 0
    
    # 各次元ごとに処理
    for k in range(M):
        # k次元目の座標を取り出す
        col = [coords[i][k] for i in range(N)]
        # ソートする
        col.sort()
        # 各要素が他の要素に対してどれだけ寄与するかを計算
        for i in range(N):
            # col[i] が引かれる回数: (N - 1 - i)
            # col[i] が足される回数: i
            total += col[i] * (i - (N - 1 - i))
    
    print(total)

if __name__ == "__main__":
    main()

この解説は qwen3-coder-480b によって生成されました。

投稿日時:
最終更新: