Official

H - Avoid K Partition Editorial by en_translator


For ease of explanation, let us rephrase the problem as follows:

There are points \(1, 2, \dots, N+1\) arranged in this order.
In how many ways, modulo \(998244353\), can we place “partitions” on some of the points so that:

  • points \(1\) and \(N+1\) always have a partition, and
  • if points \(i \lt j\) both have a partition and no point strictly between them has one, then \(\sum_{k=i}^{j-1} A_k \neq K\)?
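To make the rephrased statement concrete, the following is a minimal brute-force sketch that enumerates every placement of partitions on the interior points \(2, \dots, N\) and checks each segment directly. The function name `count_bruteforce` and the limit \(N \le 20\) are assumptions made for this illustration; it is only meant for cross-checking on tiny inputs.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: counts valid placements by trying all 2^(N-1) subsets
// of the interior points. Exponential time, so only usable for tiny N.
long long count_bruteforce(const std::vector<long long>& A, long long K) {
  const int N = static_cast<int>(A.size());
  assert(1 <= N && N <= 20);  // assumption of this sketch, not of the problem
  long long ways = 0;
  // Bit (p - 2) of mask is set <=> interior point p (2 <= p <= N) has a partition.
  for (unsigned mask = 0; mask < (1u << (N - 1)); mask++) {
    bool ok = true;
    long long seg = 0;  // sum of the segment currently being built
    for (int k = 1; k <= N && ok; k++) {
      seg += A[k - 1];  // A_k in the 1-indexed notation of the text
      // Point k + 1 closes the segment if it is point N + 1 or a chosen interior point.
      bool partition_at_next = (k == N) || ((mask >> (k - 1)) & 1u);
      if (partition_at_next) {
        if (seg == K) ok = false;  // forbidden segment sum
        seg = 0;
      }
    }
    if (ok) ways++;
  }
  return ways % 998244353;
}
```

For instance, with \(A = (1, 2, 3)\) and \(K = 3\), the placements that survive are "no interior partition" (sum \(6\)) and "partition at point \(2\)" (sums \(1\) and \(5\)), so the function returns \(2\).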

This new version admits an \(\mathrm{O}(N^2)\) dynamic programming (DP) solution. Specifically, we can define

  • \(dp[i]\): the number of valid ways to place partitions on points \(1\) through \(i\) such that point \(i\) has a partition.

The initial state is \(dp[1] = 1\) and the answer is \(dp[N + 1]\), with the transition being

\[dp[n] = \sum_{1 \leq m \lt n} \left(0\text{ if }\sum_{k=m}^{n-1} A_k = K \text{ else } dp[m] \right).\]
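The transition above can be sketched directly as a double loop. This is a minimal illustration rather than the editorial's sample code: the function name `count_naive` is hypothetical, and plain `long long` arithmetic is used in place of a modint type so that the block needs only the standard library.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: the O(N^2) DP, with dp indexed by points 1..N+1.
long long count_naive(const std::vector<long long>& A, long long K) {
  assert(!A.empty());  // this sketch assumes N >= 1
  const long long MOD = 998244353;
  const int N = static_cast<int>(A.size());
  std::vector<long long> dp(N + 2, 0);
  dp[1] = 1;  // initial state: point 1 always has a partition
  for (int n = 2; n <= N + 1; n++) {
    long long seg = 0;
    // Walk m downward so that seg = A_m + ... + A_{n-1} is maintained incrementally.
    for (int m = n - 1; m >= 1; m--) {
      seg += A[m - 1];  // A_m is stored at 0-based index m - 1
      if (seg != K) dp[n] = (dp[n] + dp[m]) % MOD;
    }
  }
  return dp[N + 1];
}
```

Calling `count_naive({1, 2, 3}, 3)` yields \(2\), matching a hand enumeration of that small input.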

We now optimize this DP. Define the cumulative sums of \(A\) as \(B_i = \sum_{j = 1}^i A_j\) (so \(B_0 = 0\)); then the transition above turns into

\[ \begin{aligned} dp[n] &= \sum_{1 \leq m \lt n} \left(0\text{ if } B_{n-1} - B_{m-1} = K \text{ else } dp[m] \right) \\ &= \sum_{1 \leq m \lt n, B_{m-1} \neq B_{n-1} - K} dp[m]. \end{aligned} \]

In other words, for a fixed \(n\), whether \(dp[m]\) contributes to \(dp[n]\) depends solely on the value of \(B_{m-1}\).
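As a small hand-checked illustration (the input \(A = (1, 2, 3)\), \(K = 3\) is chosen here only as an example, and \(B_0 = 0\) is taken as the empty sum), the prefix sums are \((B_0, B_1, B_2, B_3) = (0, 1, 3, 6)\) and the filtered transition evaluates as

\[ \begin{aligned} dp[1] &= 1, \\ dp[2] &= dp[1] = 1 && (B_1 - K = -2 \text{ matches no } B_{m-1}), \\ dp[3] &= dp[2] = 1 && (B_2 - K = 0 = B_0 \text{ excludes } m = 1), \\ dp[4] &= dp[1] + dp[2] = 2 && (B_3 - K = 3 = B_2 \text{ excludes } m = 3), \end{aligned} \]

so the answer for this input would be \(dp[4] = 2\).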

We therefore group the DP values by prefix sum in an associative array (std::map in C++). Define an associative array \(M\) by

  • \(M[x]\) : the sum of \(dp[m]\) over \(m\) with \(B_{m-1} = x\).

In addition, maintain the sum of all \(dp[m]\) computed so far in a variable \(all\). Then the transition of the DP can be written as

\[dp[n] = all - M[B_{n-1} - K],\]

which can be evaluated fast. After computing \(dp[n]\), we update \(M\) and \(all\) by adding \(dp[n]\) to \(M[B_{n-1}]\) and to \(all\), which is also fast.

By implementing the algorithm above appropriately, the problem can be solved in a total of \(\mathrm{O}(N \log N)\) time with an ordered map such as std::map, or expected \(\mathrm{O}(N)\) time with a hash map, which is fast enough.

The following is sample code in C++.

#include <iostream>
#include <map>
#include <vector>
using namespace std;

#include "atcoder/modint.hpp"
using mint = atcoder::modint998244353;

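int main() {
  long long N, K;
  cin >> N >> K;
  vector<long long> A(N);
  for (auto& a : A) cin >> a;

  // M[x]: sum of dp[m] over all m processed so far with B_{m-1} = x.
  map<long long, mint> M;
  M[0] = 1;           // dp[1] = 1, stored under B_0 = 0
  mint all = 1;       // sum of every dp[m] computed so far
  long long acc = 0;  // running prefix sum; equals B_{i+1} after the update below

  for (int i = 0; i < N; i++) {
    acc += A[i];
    long long ban = acc - K;  // dp[m] with B_{m-1} = ban must be excluded
    mint cur = all - M[ban];  // cur = dp[i + 2]
    M[acc] += cur, all += cur;
    if (i + 1 == N) cout << cur.val() << "\n";  // dp[N + 1] is the answer
  }
}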
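
As a variant of the same idea, the sketch below swaps std::map for std::unordered_map to aim for expected \(\mathrm{O}(N)\) time. It is an illustration under assumptions, not the editorial's code: the function name `count_fast` is hypothetical, modular arithmetic is done by hand on `long long` instead of using a modint type, and the default hash is used, which can degrade on adversarially constructed inputs.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical helper: same DP as the map-based approach, using a hash map.
long long count_fast(const std::vector<long long>& A, long long K) {
  assert(!A.empty());  // this sketch assumes N >= 1
  const long long MOD = 998244353;
  std::unordered_map<long long, long long> M;  // M[x]: sum of dp[m] with B_{m-1} = x
  M.reserve(A.size() * 2 + 1);                 // reduce rehashing
  M[0] = 1;           // dp[1] = 1, stored under B_0 = 0
  long long all = 1;  // sum of every dp[m] computed so far, modulo MOD
  long long acc = 0;  // running prefix sum
  long long cur = 0;  // most recently computed dp value
  for (long long a : A) {
    acc += a;
    // Look up without inserting, so absent keys do not grow the table.
    auto it = M.find(acc - K);
    long long banned = (it == M.end()) ? 0 : it->second;
    cur = (all - banned + MOD) % MOD;  // keep the difference non-negative
    M[acc] = (M[acc] + cur) % MOD;
    all = (all + cur) % MOD;
  }
  return cur;  // after the last element this is dp[N + 1]
}
```

Using `find` instead of `operator[]` for the lookup is a deliberate choice here: it avoids inserting a zero entry for every queried key, which keeps the table at no more than \(N + 1\) entries.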
