H - Avoid K Partition Editorial by en_translator
For the purpose of explanation, let us rephrase the problem as follows:
There are points \(1, 2, \dots, N+1\) arranged in this order.
How many ways are there, modulo \(998244353\), to choose some of the points and put a “partition” on each of them so that:
- points \(1\) and \(N+1\) always have a partition, and
- whenever points \(i\) and \(j\) (\(i < j\)) both have a partition and no point strictly between them has one, \(\sum_{k=i}^{j-1} A_k \neq K\)?
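To make the rephrased problem concrete, here is a brute-force sketch that enumerates every choice of partitions directly. It is exponential in \(N\), so it is only meant for cross-checking small cases. The function name count_brute_force is hypothetical, A is assumed 0-indexed (so A[k] holds \(A_{k+1}\)), and the count is returned exactly rather than modulo \(998244353\).

```cpp
#include <cassert>
#include <vector>

// Hypothetical brute force for the rephrased problem.
// Points 1 and N+1 always carry a partition; every subset of the interior
// points 2..N is tried, and a choice is valid if no segment between
// consecutive partitions sums to K.
// Assumes 1 <= N and N is small (2^(N-1) subsets are enumerated).
long long count_brute_force(const std::vector<long long>& A, long long K) {
    const int N = static_cast<int>(A.size());
    long long ways = 0;
    // Bit k of mask set means point k+2 (an interior point) has a partition.
    for (long long mask = 0; mask < (1LL << (N - 1)); mask++) {
        bool ok = true;
        long long segment = 0;  // sum of the segment currently being built
        for (int k = 0; k < N && ok; k++) {
            segment += A[k];  // A[k] is A_{k+1}, which lies just before point k+2
            // The segment closes at point k+2 if it has a partition;
            // point N+1 (k == N-1) always has one.
            bool cut = (k == N - 1) || ((mask >> k) & 1);
            if (cut) {
                if (segment == K) ok = false;
                segment = 0;
            }
        }
        if (ok) ways++;
    }
    return ways;
}
```

For example, with \(N = 3\), \(K = 3\), \(A = (1, 2, 3)\) it returns 2: only the partition sets \(\{1, 4\}\) and \(\{1, 2, 4\}\) avoid a segment summing to \(3\).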
This formulation admits an \(\mathrm{O}(N^2)\) dynamic programming (DP) solution. Specifically, define
- \(dp[i]\): the number of ways to place partitions on points \(1\) through \(i\) such that point \(i\) has a partition and every segment between consecutive partitions among these points satisfies the condition.
The initial state is \(dp[1] = 1\) and the answer is \(dp[N + 1]\). Classifying the choices by the position \(m\) of the last partition before point \(n\), the transition is
\[dp[n] = \sum_{1 \leq m \lt n} \left(0\text{ if }\sum_{k=m}^{n-1} A_k = K \text{ else } dp[m] \right).\]
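The \(\mathrm{O}(N^2)\) DP can be sketched directly from this transition. For each \(n\), the inner loop walks \(m\) downward from \(n-1\), so the segment sum \(\sum_{k=m}^{n-1} A_k\) grows by one term per step instead of being recomputed. The function name count_quadratic is hypothetical, and A is assumed 0-indexed.

```cpp
#include <cassert>
#include <vector>

// Hypothetical O(N^2) version of the DP, with 1-indexed points:
// dp[n] = sum of dp[m] over 1 <= m < n with A_m + ... + A_{n-1} != K,
// reduced modulo 998244353.
long long count_quadratic(const std::vector<long long>& A, long long K) {
    const long long MOD = 998244353;
    const int N = static_cast<int>(A.size());
    std::vector<long long> dp(N + 2, 0);
    dp[1] = 1;  // point 1 always has a partition
    for (int n = 2; n <= N + 1; n++) {
        long long seg = 0;  // A_m + ... + A_{n-1}, extended as m decreases
        for (int m = n - 1; m >= 1; m--) {
            seg += A[m - 1];  // A_m is stored at A[m - 1]
            if (seg != K) dp[n] = (dp[n] + dp[m]) % MOD;
        }
    }
    return dp[N + 1];
}
```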
Let us optimize this DP. Define the prefix sums of \(A\) as \(B_i = \sum_{j = 1}^i A_j\) (so \(B_0 = 0\)). Since \(\sum_{k=m}^{n-1} A_k = B_{n-1} - B_{m-1}\), the transition above turns into
\[ \begin{aligned} dp[n] &= \sum_{1 \leq m \lt n} \left(0\text{ if } B_{n-1} - B_{m-1} = K \text{ else } dp[m] \right) \\ &= \sum_{1 \leq m \lt n, B_{m-1} \neq B_{n-1} - K} dp[m]. \end{aligned} \]
In other words, for a fixed \(n\), whether \(dp[m]\) contributes to \(dp[n]\) depends only on the value of \(B_{m-1}\).
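So we group the DP values by \(B_{m-1}\) using an associative array (std::map in C++). Define an associative array \(M\) by
- \(M[x]\): the sum of \(dp[m]\) over all already computed indices \(m\) with \(B_{m-1} = x\).
Also maintain the sum of all \(dp[m]\) computed so far in a variable \(all\). When we compute \(dp[n]\), \(M\) and \(all\) cover exactly the indices \(1 \leq m < n\), so the transition of the DP can be written as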
\[dp[n] = all - M[B_{n-1} - K],\]
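which requires only one lookup in \(M\). After computing \(dp[n]\), we add it to both \(M[B_{n-1}]\) and \(all\); with std::map, each of these operations takes \(\mathrm{O}(\log N)\) time.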
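As a hand-checked illustration, consider the small input \(N = 3\), \(K = 3\), \(A = (1, 2, 3)\) (chosen here only as an example). Its prefix sums are \(B_0 = 0\), \(B_1 = 1\), \(B_2 = 3\), \(B_3 = 6\), and each row below shows \(M\) and \(all\) after \(dp[n]\) has been added to them:
\[
\begin{array}{c|c|c|c|c}
n & B_{n-1} - K & dp[n] = all - M[B_{n-1} - K] & M & all \\ \hline
1 & - & 1 \ (\text{initial}) & \{0:1\} & 1 \\
2 & -2 & 1 - 0 = 1 & \{0:1,\ 1:1\} & 2 \\
3 & 0 & 2 - 1 = 1 & \{0:1,\ 1:1,\ 3:1\} & 3 \\
4 & 3 & 3 - 1 = 2 & \{0:1,\ 1:1,\ 3:1,\ 6:2\} & 5
\end{array}
\]
For instance, at \(n = 3\) the looked-up key \(B_2 - K = 0\) removes \(dp[1]\), which corresponds to the forbidden segment \(A_1 + A_2 = 3\). The answer is \(dp[4] = 2\).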
Implemented this way, the whole algorithm runs in \(\mathrm{O}(N \log N)\) time with a balanced-tree map such as std::map, or in \(\mathrm{O}(N)\) expected time with a hash map, which is fast enough.
The following is sample code in C++.
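```cpp
#include <iostream>
#include <map>
#include <vector>
#include "atcoder/modint.hpp"  // AtCoder Library (ACL)
using namespace std;
using mint = atcoder::modint998244353;

int main() {
    long long N, K;
    cin >> N >> K;
    vector<long long> A(N);
    for (auto& a : A) cin >> a;

    // M[x] = sum of dp[m] over computed m with B_{m-1} = x.
    // Initially only dp[1] = 1 is known, and B_0 = 0.
    map<long long, mint> M;
    M[0] = 1;
    mint all = 1;       // sum of all dp values computed so far
    long long acc = 0;  // prefix sum B_{n-1}, where n = i + 2
    for (int i = 0; i < N; i++) {
        acc += A[i];
        long long ban = acc - K;  // dp[m] with B_{m-1} = B_{n-1} - K is excluded
        mint cur = all - M[ban];  // dp[n]
        M[acc] += cur, all += cur;
        if (i + 1 == N) cout << cur.val() << "\n";  // dp[N + 1] is the answer
    }
}
```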
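The sample code above uses std::map, which gives the \(\mathrm{O}(N \log N)\) bound. Below is a minimal sketch of the \(\mathrm{O}(N)\) expected-time variant, assuming std::unordered_map hashes the prefix sums well (adversarial inputs can degrade hash tables). It also replaces atcoder::modint with plain arithmetic modulo \(998244353\) so that it has no external dependency. The function name count_linear is hypothetical.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical hash-map version of the same DP (O(N) expected time).
long long count_linear(const std::vector<long long>& A, long long K) {
    const long long MOD = 998244353;
    // M[x] = sum of dp[m] over computed m with B_{m-1} = x (mod MOD).
    std::unordered_map<long long, long long> M;
    M.reserve(A.size() * 2 + 1);
    M[0] = 1;           // dp[1] = 1 and B_0 = 0
    long long all = 1;  // sum of all dp values computed so far (mod MOD)
    long long acc = 0;  // running prefix sum B_{n-1}
    long long cur = 1;  // dp[n] for the latest n
    for (long long a : A) {
        acc += a;
        // find() avoids inserting a zero entry for keys never seen.
        auto it = M.find(acc - K);
        long long banned = (it == M.end()) ? 0 : it->second;
        cur = (all - banned + MOD) % MOD;  // dp[n] = all - M[B_{n-1} - K]
        M[acc] = (M[acc] + cur) % MOD;     // dp[n] is filed under B_{n-1}
        all = (all + cur) % MOD;
    }
    return cur;  // dp[N + 1]
}
```

Called as count_linear({1, 2, 3}, 3), it returns 2.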