Official

G - Linear Inequation Editorial by en_translator


For a sequence of positive integers \(A=(A _ 1,A _ 2,\ldots,A _ N)\) and a non-negative integer \(M\), denote the answer to this problem by \(f(A,M)\).

Also, for sequences of non-negative integers \(A=(A _ 1,A _ 2,\ldots,A _ N)\) and \(B=(B _ 1,B _ 2,\ldots,B _ N)\), define \(A\times B\) as the non-negative integer \(\displaystyle\sum _ {i=1} ^ NA _ iB _ i\), and \(A+B\) as the sequence of non-negative integers \((A _ 1+B _ 1,A _ 2+B _ 2,\ldots,A _ N+B _ N)\). Moreover, for a sequence of non-negative integers \(A=(A _ 1,A _ 2,\ldots,A _ N)\) and a non-negative integer \(c\), let \(cA\) denote the sequence of non-negative integers \((cA _ 1,cA _ 2,\ldots,cA _ N)\).

Note that \(A\times(xB+yC)=x(A\times B)+y(A\times C)\) holds for all sequences \(A,B,C\) and all non-negative integers \(x,y\).


Fix a positive integer \(d\). For every sequence of non-negative integers \(X=(X _ 1,X _ 2,\ldots,X _ N)\), there exists a unique pair of sequences of non-negative integers \(Q=(Q _ 1,Q _ 2,\ldots,Q _ N),R=(R _ 1,R _ 2,\ldots,R _ N)\) such that:

  • \(X=dQ+R\),
  • \(0\le R _ i\lt d\ (1\le i\le N)\).

Therefore, counting the sequences of non-negative integers \(X\) that satisfy the conditions is equivalent to counting the pairs \((Q,R)\) of sequences of non-negative integers that satisfy the correspondingly rewritten conditions.

What we want to find is the number of sequences \(X\) with \(A\times X\leq M\). Since \(A\times X=d(A\times Q)+A\times R\), we count the pairs \((Q,R)\) with \(d(A\times Q)+A\times R\leq M\). Let \(S\) be the multiset of values \(A\times R\) over all sequences \(R\) whose elements are integers between \(0\) and \(d-1\). For a fixed \(R\) with \(A\times R=s\), the condition on \(Q\) reads \(A\times Q\le\left\lfloor\dfrac{M-s}d\right\rfloor\), so, with the convention \(f(A,m)=0\) for \(m\lt0\), \[f(A,M)=\sum _ {s\in S}f\!\left(\!A,\left\lfloor\dfrac{M-s}d\right\rfloor\right).\]

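Given non-negative integers \(\lbrace c _ i\rbrace _ {i\in\mathbb Z}\) such that \(c _ i\ne 0\) for only finitely many indices \(i\), consider how to evaluate \(\displaystyle\sum _ {i\in\mathbb Z}c _ if(A,i)\). Substituting the equation above, this can be transformed as \[\begin{aligned}\sum _ {i\in\mathbb Z}c _ if(A,i)&=\sum _ {i\in\mathbb Z}c _ i\sum _ {s\in S}f\!\left(\!A,\left\lfloor\dfrac{i-s}d\right\rfloor\right)\\&=\sum _ {i\in\mathbb Z}\left(\sum _ {j=0} ^ {d-1}\sum _ {s\in S}c _ {id+j+s}\right)f(A,i)\end{aligned}.\] Writing \(k\) for the summation index of the first line, the second equality collects, for each \(i\), the terms \(c _ kf\!\left(\!A,\left\lfloor\frac{k-s}d\right\rfloor\right)\) with \(\left\lfloor\frac{k-s}d\right\rfloor=i\), that is, those with \(k=id+j+s\) for some \(0\le j\le d-1\).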

Consider taking \(\displaystyle\sum _ {j=0} ^ {d-1}\sum _ {s\in S}c _ {id+j+s}\) as the new \(c _ i\) for every \(i\). If \(\lbrace c _ i\rbrace\) and \(S\) (as a table of multiplicities) are stored as arrays, this operation is one convolution (storing \(\lbrace c _ i\rbrace\) in reverse order, as the sample code does) followed by summing over blocks of \(d\) consecutive indices, so it takes \(O((L _ 0+L _ 1)\log(L _ 0+L _ 1))\) time, where \(\displaystyle L _ 0\coloneqq\max _ {c _ i\ne0}i-\min _ {c _ i\ne0}i\) and \(L _ 1\coloneqq\max _ {s\in S}s\). One operation turns the maximum index \(i\) with \(c _ i\ne0\) into \(\lfloor i/d\rfloor\), so one can start from \(c _ M=1\) and \(c _ i=0\ (i\ne M)\) and repeat the operation \(\lfloor\log _ dM\rfloor+1\) times (for \(M\ge1\)), after which every index \(i\) with \(c _ i\ne0\) satisfies \(i\le0\). Since \(f(A,i)=0\ (i\lt0)\) and \(f(A,0)=1\), the value \(c _ 0\) at this point is the sought value. For the same reason, the coefficients at negative indices may be discarded after each operation.

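Clearly, \(\displaystyle L _ 1=(d-1)\sum _ iA _ i\). Starting from \(c _ M=1\) and \(c _ i=0\ (i\ne M)\), one operation maps the index range \([\ell,h]\) of nonzero coefficients into \(\left[\left\lfloor\dfrac{\ell-L _ 1}d\right\rfloor,\left\lfloor\dfrac hd\right\rfloor\right]\), so the new \(L _ 0\) is at most \(\dfrac{L _ 0+L _ 1+d-1}d\). By induction (using \((d-1)\sum _ iA _ i\ge1\)), \(L _ 0\le L _ 1\dfrac d{d-1}=d\displaystyle\sum _ iA _ i\) always holds, so \(L _ 0+L _ 1\le(2d-1)\displaystyle\sum _ iA _ i\). Hence each of the \(O(\log _ dM)\) operations takes \(O\!\left(d\sum _ iA _ i\log\left(d\sum _ iA _ i\right)\right)\) time, and the problem can be solved fast enough by choosing any fixed integer \(d\ge2\).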

The following is sample code. It spends \(O(d ^ 2\max A _ i\sum A _ i\log(d\sum A _ i))\) time computing \(S\), which is already fast enough, though there is room for improvement.

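```cpp
// Requires C++23 (views::chunk, views::zip) and the AtCoder Library (atcoder::convolution).
#include <algorithm>
#include <iostream>
#include <vector>
#include <ranges>
#include <utility>
#include <atcoder/convolution>

int main() {
    using namespace std;
    using modint = atcoder::static_modint<998244353>;
    unsigned N;
    unsigned long M;
    cin >> N >> M; // N is not used further; the loop below reads A_1, ..., A_N until end of input
    constexpr unsigned d{2};
    vector<modint> S{1}; // S[i] = #(sequences x of integers between 0 and d-1 such that i = ∑ A_jx_j)
    for (const auto A : views::istream<unsigned>(cin)) {
        // tmp = 1 + x^A + x^{2A} + ... + x^{(d-1)A}: the first element of each length-A chunk is 1
        vector<modint> tmp(A * (d - 1) + 1);
        for (auto&& t : tmp | views::chunk(A))
            t[0] = 1;
        S = atcoder::convolution(move(S), move(tmp));
    }
    vector<modint> coef{1}; // coef[k] holds c_{M-k}, so ∑ coef[M - i]f(A, i) is the answer
    while (M) {
        // Product indices i > M correspond to c with a negative index, where f(A, ·) = 0
        vector<modint> result(M / d - (M - min<unsigned long>(M, size(S) + size(coef) - 2)) / d + 1);
        for (const auto& [x, i] : views::zip(atcoder::convolution(S, coef), views::iota(0UL)) // Convolve
                                | views::take(M + 1)) // and eliminate the negative part
            result[M / d - (M - i) / d] += x; // contributes to the new c_{floor((M-i)/d)}
        swap(coef, result);
        M /= d;
    }
    cout << coef.front().val() << endl; // Now M = 0, so coef[0] = c_0 is the answer
    return 0;
}
```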
