E - 魔法の道 / Magic Road 解説 by admin
Claude 4.6 Opus (Thinking)概要
街 \(1\) から街 \(N\) へちょうど \(K\) 本の道を通る全経路について、通った道の魔力値の積の総和を求める問題です。隣接行列の行列累乗により効率的に解くことができます。
考察
重要な気づき:隣接行列と経路の関係
通常のグラフ理論では、隣接行列 \(A\)(辺があれば \(1\)、なければ \(0\))の \(K\) 乗 \(A^K\) の \((i, j)\) 成分は「頂点 \(i\) から頂点 \(j\) へちょうど \(K\) 本の辺を通る経路の 本数」を表すことが知られています。
この問題では辺に重み(魔力値)があり、求めたいのは経路の本数ではなく「各経路の魔力値の積の総和」です。しかし、隣接行列の \((u, v)\) 成分を魔力値 \(W(u, v)\) とすると、行列の掛け算の定義:
\[C[i][j] = \sum_{k} A[i][k] \times B[k][j]\]
がまさに「中間地点 \(k\) を経由するすべてのパターンについて、魔力値の積を足し合わせる」操作に対応します。
具体例で確認しましょう。\(K = 2\) のとき、\(A^2\) の \((1, N)\) 成分は:
\[\sum_{k=1}^{N} W(1, k) \times W(k, N)\]
これは街 \(1\) → 街 \(k\) → 街 \(N\) という長さ \(2\) の全経路について、魔力値の積を合計したものです。\(K\) 回の移動でも同様に、\(A^K\) の \((1, N)\) 成分が求める答えになります。
素朴なアプローチの問題点
\(K\) が最大 \(10^9\) と非常に大きいため、行列を \(K\) 回愚直に掛け算すると \(O(K \times N^3)\) となり、到底間に合いません。
解決策:行列の繰り返し二乗法(行列累乗)
整数の繰り返し二乗法(\(a^K\) を \(O(\log K)\) 回の掛け算で求める手法)を行列に適用します。\(K\) を二進数で表し、必要な部分だけ掛けていくことで、行列の掛け算を \(O(\log K)\) 回に抑えられます。
アルゴリズム
- \(N \times N\) の隣接行列 \(A\) を作る。\(A[u][v] = W(u, v)\)(道がなければ \(0\))。
- 行列の繰り返し二乗法で \(A^K\) を計算する。すべての演算は \(\mod 998244353\) で行う。
- \(A^K\) の \((0, N-1)\) 成分(0-indexed で街 \(1\) から街 \(N\))が答え。
繰り返し二乗法の手順: - 結果行列 \(R\) を単位行列で初期化 - \(K\) の各ビットを下位から見て、ビットが立っていれば \(R \leftarrow R \times A^{2^i}\) を掛ける - 各ステップで \(A^{2^i}\) は前のステップの値を二乗して得る
計算量
- 時間計算量: \(O(N^3 \log K)\)
- 行列の掛け算1回が \(O(N^3)\)、繰り返し二乗法で \(O(\log K)\) 回の掛け算
- \(N \leq 100\), \(K \leq 10^9\) なので、\(100^3 \times 30 \approx 3 \times 10^7\) 程度
- 空間計算量: \(O(N^2)\)
- 行列を数枚保持するだけ
実装のポイント
行列の掛け算で
A[i][k] == 0のときスキップする枝刈りを入れると、疎なグラフで高速化できる。魔力値 \(W_i\) が最大 \(10^9\) なので、入力時点で \(\mod 998244353\) を取っておく。
Python では \(N = 100\) の行列累乗が遅くなりがちだが、上記の枝刈りや
sys.stdin.buffer.read()による高速入力で対応する。経路が存在しない場合は自然に \(A^K[0][N-1] = 0\) となるため、特別な処理は不要。
ソースコード
import sys
def mat_mult(A, B, n, mod):
C = [[0] * n for _ in range(n)]
for i in range(n):
for k in range(n):
if A[i][k] == 0:
continue
aik = A[i][k]
for j in range(n):
C[i][j] = (C[i][j] + aik * B[k][j]) % mod
return C
def mat_pow(M, p, n, mod):
# Initialize result as identity matrix
R = [[0] * n for _ in range(n)]
for i in range(n):
R[i][i] = 1
base = M
while p > 0:
if p & 1:
R = mat_mult(R, base, n, mod)
base = mat_mult(base, base, n, mod)
p >>= 1
return R
def main():
input_data = sys.stdin.buffer.read().split()
idx = 0
N = int(input_data[idx]); idx += 1
M = int(input_data[idx]); idx += 1
K = int(input_data[idx]); idx += 1
MOD = 998244353
# Build adjacency matrix where adj[u][v] = W(u,v)
adj = [[0] * N for _ in range(N)]
for i in range(M):
u = int(input_data[idx]) - 1; idx += 1
v = int(input_data[idx]) - 1; idx += 1
w = int(input_data[idx]) % MOD; idx += 1
adj[u][v] = w
# Matrix exponentiation: adj^K, then answer is result[0][N-1]
result = mat_pow(adj, K, N, MOD)
print(result[0][N - 1] % MOD)
if __name__ == '__main__':
main()
この解説は claude4.6opus-thinking によって生成されました。
投稿日時:
最終更新: