Official

G - Haunted House Editorial by en_translator


Observations

Consider a graph whose vertices are the connected components of passable cells within each floor, with an edge between two components on adjacent floors whenever they share a cell position that is passable on both floors.

We compute the following values in descending order of the floor numbers:

  • \(dp_0[f][v]\): the maximum collectable coins when starting from vertex \(v\) on floor \(f\), and using a ladder on floor \((f+1)\) or below
  • \(dp_1[f][v]\): the maximum collectable coins when starting from vertex \(v\) on floor \(f\) without using a ladder
  • \(dv_0[f][v][u]\): the maximum collectable coins when starting from vertex \(v\) on floor \(f\), and ascending from \(v\) to vertex \(u\) on floor \((f+1)\)
  • \(dv_1[f+1][u][v]\): the maximum collectable coins when starting from vertex \(u\) on floor \((f+1)\), and descending to vertex \(v\) on floor \(f\)

Let \(c_v\) be the total number of coins in vertex \(v\).

\(dp_0[f][v]\) and \(dv_0[f][v][u]\) can be computed by inspecting \(dp_0[f+1][u]\) and \(dp_1[f+1][u]\) for all adjacent vertices \(u\) on floor \(f+1\).

To find \(dv_1[f+1][u][v]\) based on \(dv_0[f][v][w]\), we take the maximum value for a vertex \(w\) adjacent to \(v\):

  • \(dv_0[f][v][w]\) if \(u = w\),
  • \(dv_0[f][v][w] + c_u\) if \(u \neq w\).

Here, the maximum \(dv_0[f][v][w]\) for \(u \neq w\) is either the first or second largest value of \(dv_0[f][v][w]\). Thus, it is sufficient to find the two largest values among \(dv_0[f][v][w]\), and refer to them when processing the transitions.

Similarly, \(dp_1[f][v]\) can be computed based on \(dv_1[f+1][u][v]\) and \(dp_1[f+1][u]\).

Sample code (C++)

#include <iostream>
using std::cin;
using std::cout;
#include <vector>
using std::vector;
using std::pair;
#include <algorithm>
using std::sort;
using std::min;
using std::max;
#include <string>
using std::string;
#include <set>
using std::set;

#include <atcoder/dsu>
using atcoder::dsu;

// 64-bit signed integer used for every count and index in this file.
typedef long long int ll;
// (value, component id) pair; stored in the dv0/dv1 candidate lists of solve().
typedef pair<ll, ll> P;

// Not referenced anywhere in the visible code.
const ll MOD = 998244353;

// Number of floors, and the height and width of each floor (read in main()).
ll f, h, w;
// s[k][i]: row i of floor k exactly as read from input.
vector<vector<string> > s;
// vs[k][i][j]: filled by solve(); -1 where s[k][i][j] is '#', otherwise the digit value of the cell.
vector<vector<vector<ll> > > vs;
// Number of queries.
ll q;
// Per-query floor (g) and the two cell coordinates (x, y); main() converts them to 0-based.
vector<ll> g, x, y;

// Raises l to r when r is strictly larger; otherwise leaves l untouched.
void chmax (ll &l, ll r) {
	l = (r > l) ? r : l;
}
// Lowers l to r when r is strictly smaller; otherwise leaves l untouched.
void chmin (ll &l, ll r) {
	l = (r < l) ? r : l;
}

void solve () {
	vs.resize(f);
	for (ll k = 0; k < f; k++) {
		vs[k].resize(h);
		for (ll i = 0; i < h; i++) {
			vs[k][i].resize(w);
			for (ll j = 0; j < w; j++) {
				if (s[k][i][j] == '#') {
					vs[k][i][j] = -1;
				} else {
					vs[k][i][j] = (s[k][i][j] - '0');
				}
			}
		}
	}

	vector<dsu> ds;
	for (ll k = 0; k < f; k++) {
		dsu d(h*w);
		for (ll i = 0; i < h; i++) {
			for (ll j = 0; j < w; j++) {
				if (vs[k][i][j] == -1) continue;

				if (i-1 >= 0 && vs[k][i-1][j] != -1) {
					d.merge((i-1)*w+j, i*w+j);
				}
				if (j-1 >= 0 && vs[k][i][j-1] != -1) {
					d.merge(i*w+(j-1), i*w+j);
				}
			}
		}
		ds.push_back(d);
	}

	vector<ll> isroot[f];
	// vx[k][root]: sum of component[k][root]
	vector<ll> vx[f];
	// vgu[k][root]: edge [k]->[k+1], component-wise (unique)
	// vgd[k][root]: edge [k]->[k-1], component-wise (unique)
	vector<ll> vgu[f][h*w], vgd[f][h*w];
	for (ll k = 0; k < f; k++) {
		isroot[k].assign(h*w, 0);
		vx[k].assign(h*w, 0);

		vector<set<ll> > tu(h*w), td(h*w);
		for (ll i = 0; i < h; i++) {
			for (ll j = 0; j < w; j++) {
				if (vs[k][i][j] != -1) {
					ll idx = ds[k].leader(i*w+j);
					isroot[k][idx] = 1;
					vx[k][idx] += vs[k][i][j];

					if (k-1 >= 0 && vs[k-1][i][j] != -1) {
						ll jdx = ds[k-1].leader(i*w+j);
						td[idx].insert(jdx);
					}
					if (k+1 <  f && vs[k+1][i][j] != -1) {
						ll jdx = ds[k+1].leader(i*w+j);
						tu[idx].insert(jdx);
					}
				}
			}
		}

		for (ll i = 0; i < h*w; i++) {
			for (ll v : td[i]) vgd[k][i].push_back(v);
			for (ll v : tu[i]) vgu[k][i].push_back(v);
		}
	}

	ll dp0[f][h*w], dp1[f][h*w];
	ll up1[f][h*w];
	vector<P> dv0[f][h*w], dv1[f][h*w];
	for (ll k = f-1; k >= 0; k--) {
		for (ll v = 0; v < h*w; v++) {
			dp0[k][v] = 0;
			dp1[k][v] = 0;
			up1[k][v] = 0;

			dv0[k][v].clear();
			dv1[k][v].clear();
		}

		// simple
		for (ll v = 0; v < h*w; v++) {
			if (!isroot[k][v]) continue;

			ll origin = vx[k][v];
			ll drop0 = 0;
			ll drop1 = 0;
			if (k+1 < f) {
				for (ll u : vgu[k][v]) {
					chmax(drop0, dp0[k+1][u]);
					chmax(drop1, dp1[k+1][u]);
				}
			}

			dp0[k][v] = origin + drop0;
			dp1[k][v] = origin + drop1;
		}
		// simple end

		// zigzag
		if (k+1 < f) {
			// [k][0]->[k+1][0]
			for (ll v = 0; v < h*w; v++) {
				if (!isroot[k][v]) continue;

				ll curr = vx[k][v];
				for (ll u : vgu[k][v]) {
					ll sum = curr + dp0[k+1][u];
					dv0[k][v].push_back({sum, u});
				}
				sort(   dv0[k][v].begin(), dv0[k][v].end()); // LOG
				reverse(dv0[k][v].begin(), dv0[k][v].end());
			}
			// [k+1][1]->[k][0]
			for (ll v = 0; v < h*w; v++) {
				if (!isroot[k+1][v]) continue;

				ll curr = vx[k+1][v];
				for (ll u : vgd[k+1][v]) {
					ll sum = 0;
					for (ll ui = 0; ui < min(2LL, (ll)dv0[k][u].size()); ui++) {
						ll add = dv0[k][u][ui].first;
						ll idx = dv0[k][u][ui].second;
						if (idx == v) {
							chmax(sum, add);
						} else {
							chmax(sum, curr + add);
						}
					}
					chmax(up1[k+1][v], sum);
					dv1[k+1][v].push_back({sum, u});
				}
				sort(   dv1[k+1][v].begin(), dv1[k+1][v].end()); // LOG
				reverse(dv1[k+1][v].begin(), dv1[k+1][v].end());
			}
			// [k][1]->[k+1][1]
			for (ll v = 0; v < h*w; v++) {
				if (!isroot[k][v]) continue;

				ll curr = vx[k][v];
				for (ll u : vgu[k][v]) {
					ll sum = 0;
					for (ll ui = 0; ui < min(2LL, (ll)dv1[k+1][u].size()); ui++) {
						ll add = dv1[k+1][u][ui].first;
						ll idx = dv1[k+1][u][ui].second;
						if (idx == v) {
							chmax(sum, add);
						} else {
							chmax(sum, curr + add);
						}
					}
					
					chmax(dp1[k][v], sum);
				}
			}
		}
		// zigzag end
	}

	for (ll qi = 0; qi < q; qi++) {
		ll ans = 0;
		ll idx = ds[g[qi]].leader(x[qi] * w + y[qi]);
		chmax(ans, dp0[g[qi]][idx]);
		chmax(ans, dp1[g[qi]][idx]);
		chmax(ans, up1[g[qi]][idx]);

		cout << ans << "\n";
	}

	return;
}

int main (void) {
	std::cin.tie(nullptr);
	std::ios_base::sync_with_stdio(false);

	cin >> f >> h >> w;
	s.resize(f);
	for (ll k = 0; k < f; k++) {
		s[k].resize(h);
		for (ll i = 0; i < h; i++) {
			cin >> s[k][i];
		}
	}

	cin >> q;
	g.resize(q);
	x.resize(q);
	y.resize(q);
	for (ll i = 0; i < q; i++) {
		cin >> g[i] >> x[i] >> y[i];
		g[i]--;
		x[i]--;
		y[i]--;
	}

	solve();

	
	return 0;
}

posted:
last update: