G - Road Blocked 2 Editorial by blueberry1001


数え上げによる解法を紹介します。

まず、頂点 \(1\) を始点とするダイクストラ法と頂点 \(N\) を始点とするダイクストラ法を行い、各頂点について

  • 頂点 \(1\) からその頂点までの最短距離
  • 頂点 \(1\) から最短距離でその頂点にたどり着くような経路の数
  • 頂点 \(N\) からその頂点までの最短距離
  • 頂点 \(N\) から最短距離でその頂点にたどり着くような経路の数

を求めます。ここで、

  • ある辺 \(E\) を削除したときに頂点 \(1\) から\(N\)までの最短距離が変化する

は、

  • 頂点 \(1\) から頂点 \(N\) までの最短経路であって辺\(E\)を通るものが存在する
  • (頂点 \(1\) から頂点 \(N\) へ最短距離でたどり着くような経路の数)と、(頂点 \(1\) から頂点 \(N\) へ最短距離でたどり着くような経路のうち辺 \(E\) を通る経路の数)が等しい

の両方を満たすことと同値です。

「頂点 \(1\) から頂点 \(N\) へ最短距離でたどり着くような経路のうち辺 \(E\) を通る経路の数」は、最初に求めた4つの値から計算できます。

経路数はとても大きくなる可能性があるため、適宜modをとって対応するとよいでしょう。一般的に衝突が心配な場合は複数modにするとよいですが、この問題では単一modでもACすることはできました。

実装例(C++)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

#include<atcoder/modint>
using mint = atcoder::modint998244353;
const ll INF = ((1LL<<62)-(1LL<<31));

int main(){
	int n,m;cin >> n >> m;
	vector<vector<pair<ll,ll>>>g(n);
	vector<tuple<ll,ll,ll>>edges;
	for(int i=0;i<m;i++){
	  ll u,v,w;cin >> u >> v >> w;
		u--,v--;
		g[u].push_back({v,w});
		g[v].push_back({u,w});
		edges.push_back({u,v,w});
	}
	vector<ll>dist1(n,INF),dist2(n,INF);
	dist1[0] = 0;
	dist2[n-1] = 0;
	vector<mint>cnt1(n),cnt2(n);
	cnt1[0] = 1;
	cnt2[n-1] = 1;
	priority_queue<pair<ll,ll>,vector<pair<ll,ll>>,greater<pair<ll,ll>>>pq;
	pq.push({0,0});
	while(pq.size()){
		auto[d,pos] = pq.top();pq.pop();
		if(d>dist1[pos])continue;
		for(auto [to,cost]:g[pos]){
			if(dist1[to]>d+cost){
			  dist1[to] = d+cost;
				cnt1[to] = cnt1[pos];
				pq.push({d+cost,to});
			}
			else if(dist1[to]==d+cost){
				cnt1[to] += cnt1[pos];
			}
		}
	}

	pq.push({0,n-1});
	while(pq.size()){
		auto[d,pos] = pq.top();pq.pop();
		if(d>dist2[pos])continue;
		for(auto [to,cost]:g[pos]){
			if(dist2[to]>d+cost){
			  dist2[to] = d+cost;
				cnt2[to] = cnt2[pos];
				pq.push({d+cost,to});
			}
			else if(dist2[to]==d+cost){
				cnt2[to] += cnt2[pos];
			}
		}
	}
	for(auto[l,r,w]:edges){
		//l->rに
		if(dist1[r]+dist2[l]<dist1[l]+dist2[r])swap(l,r);
		//そもそもこの辺を使うともとの最短距離より悪化する場合、明らかに最短距離は変化しない
		if(dist1[l]+dist2[r]+w!=dist2[0]){
			cout << "No" << endl;
			
			continue;
		}
		if((cnt1[l]*cnt2[r]).val()==cnt1[n-1].val()){
			cout << "Yes" << endl;
		}
		else{
			cout << "No" << endl;
		}
	}
	return 0;
}

提出コード(C++)

posted:
last update: