公式

E - Subset Product Problem 解説 by sounansya


この問題の解法として、(TLE になりますが)以下のような簡単なアルゴリズムが存在します。

  • \(d_k[x]\) を「\(A_i = x\) を満たす整数 \(1\le i\le k\) 全てに対する \(A_i\) の総積」と定義する。
  • \(d_{k-1}[x]\) から \(d_k[x]\) への更新は in-place に更新することで \(O(1)\) で行うことができる。
  • \(k=k_0\) の場合の答えは \(\displaystyle \prod_{x\subset A_{k_0}} d[x]\) となる。この値は \(O\left(2^{\mathrm{popcount}\left(A_{k_0}\right)}\right)\) で計算できる。

しかし、このアルゴリズムの計算量は最悪ケースで \(\Theta(N \max A)\) となり TLE となります。このアルゴリズムを改善することを考えます。


\(0\le p,q < 2^{10}\) に対し \((p,q)=2^{10}p+q\) とします。\(0\le x < 2^{20}\) を満たす整数 \(x\)\(0\le p,q < 2^{10}\) を満たす整数の組 \((p,q)\) が上の写像で一対一に対応することに注意してください。以降はこれらを同一視します。また、 \(A_k=(P_k,Q_k)\) として \(P_k,Q_k\) し、 \(1_{[\text{cond}]}\)\(\text{cond}\) が真なら \(1\) を、偽なら \(0\) を返す関数とします。

さらに、 \(x\ \mathrm{OR}\ y = y\) が成り立つことを \(x \subset y\) と表記します。

\(0\le p,q < 2^{10}\) に対し \(\displaystyle d_k[(p,q)]=\prod_{1\le i\le k} A_i^{1_{[P_i \subset p \land Q_i = q]}}\) と定義します。初期値は \(d_0[(p,q)]=1\) です。

このように \(d\) を定義すると、 \(d\) の更新と答えの計算は以下のようにできます。

\(d_{k-1}\) から \(d_k\) の更新

\(d_k\)\(d_{k-1}\) で in-place に更新した後、 \(p \supset P_k\) に対し \(d_k[(p,Q_k)]\)\(A_k\) を掛ければ良いです。

答えの計算

答えとなる \(\displaystyle \prod_{\substack{1\le i\le k\\ A_i \subset A_k} }A_i\) を式変形していきます。

\[ \begin{aligned} &\phantom{=} \prod_{\substack{1\le i\le k\\ A_i \subset A_k} }A_i\\ &= \prod_{1\le i\le k} A_i^{1_{[P_i \subset P_k \land Q_i \subset Q_k]}} \\ &=\prod_{q \subset Q_k} \prod_{1\le i\le k} A_i^{1_{[P_i \subset P_k \land Q_i =q]}} \\ &= \prod_{q\subset Q_k} d_k[(P_k , q)] \end{aligned} \]

したがって、 \(\displaystyle \prod_{q\subset Q_k} d_k[(P_k , q)]\) が答えとなります。


\(d_{k-1}\) から \(d_k\) への更新と答えの計算の両方 \(\sqrt{\max A}=2^{10}\) 回の計算で行うことができます。したがって、全体としての計算量は \(O(N\sqrt{\max A})\) となります。

以上を適切に実装することでこの問題に正答することができます。

実装例(Python3)

n = int(input())
INF = 1 << 10
MOD = 998244353
d = [1] * (1 << 20)
ans = [0] * n
cnt = 0
for x in list(map(int, input().split())):
    x1, x2 = x >> 10, x & (INF - 1)
    for i in range(INF):
        if (x1 & i) == x1:
            d[(i << 10) | x2] = x * d[(i << 10) | x2] % MOD
    res = 1
    for i in range(INF):
        if (x2 & i) == i:
            res = res * d[(x1 << 10) | i] % MOD
    ans[cnt] = res
    cnt += 1
print(*ans, sep="\n")

投稿日時:
最終更新: