Official

E - Dice Product 3 Editorial by Nyaan


DP で解くことを考えます。

\(\mathrm{dp}(n)\) : (持っている数が \(n\) である状態からスタートしたときに, 最終的に持っている数が \(N\) に一致する確率)

と定義すると、答えは \(\mathrm{dp}(1)\) です。

\(\mathrm{dp}(n)\) に関する式を求めます。\(n\) を持っている状態から 1 回サイコロを振ると、持っている数は \(n, 2n, 3n, 4n, 5n, 6n\) のいずれかに等確率に変化します。よって

\[\mathrm{dp}(n) = \frac{1}{6} (\mathrm{dp}(n) + \mathrm{dp}(2n) + \mathrm{dp}(3n) + \mathrm{dp}(4n) + \mathrm{dp}(5n) + \mathrm{dp}(6n) )\]

です。この式は両辺に \(\mathrm{dp}(n)\) が含まれているので式変形によって取り除くと

\[\mathrm{dp}(n) - \frac{1}{6} \mathrm{dp}(n) = \frac{1}{6} (\mathrm{dp}(2n) + \mathrm{dp}(3n) + \mathrm{dp}(4n) + \mathrm{dp}(5n) + \mathrm{dp}(6n) )\]

\[\frac{5}{6} \mathrm{dp}(n) = \frac{1}{6} (\mathrm{dp}(2n) + \mathrm{dp}(3n) + \mathrm{dp}(4n) + \mathrm{dp}(5n) + \mathrm{dp}(6n) )\]

\[\mathrm{dp}(n) = \frac{1}{5} (\mathrm{dp}(2n) + \mathrm{dp}(3n) + \mathrm{dp}(4n) + \mathrm{dp}(5n) + \mathrm{dp}(6n) )\]

となり、\(\mathrm{dp}(n)\) に関する式を得ることができました。

また、\(\mathrm{dp}(N) = 1\) であり、さらに \(n \geq N\) のとき \(\mathrm{dp}(n) = 0\) です。以上の条件を全て使うと、\(\mathrm{dp}(n)\) をメモ化再帰で計算するアルゴリズムが有限の時間で終了することがわかります。

  • この式は左辺が \(\mathrm{dp}(n)\) に関する式で右辺が \(\mathrm{dp}(kn)\) \((2 \leq k \leq 6)\) に関する式なので、再帰が深くなるごとに呼び出される値は大きくなります。ここから、再帰の最中にループが発生することはありえないのが確認できます。

計算量を考えます。メモ化再帰で訪問する \(\mathrm{dp}(n)\) の種類を \(X\) とすると, std::map 等を用いたメモ化再帰が \(\mathrm{O}(X \log X)\) 程度で動作して十分高速であると言えるので、この \(X\) の値の上界を評価します。
サイコロの目は \(2, 3, 5\) の倍数のみが書かれているので、\(n\) は非負整数の組 \((a, b, c)\) を用いて \(n = 2^a 3^b 5^c\) と表せる数に限ります。\(n \leq N\) という条件と合わせると

\[ \begin{aligned} &0 \leq a \leq \log_2 N \leq \log_2(10^{18}) < 60 \\ &0 \leq b \leq \log_3 N \leq \log_3(10^{18}) < 38 \\ &0 \leq c \leq \log_5 N \leq \log_5(10^{18}) < 26 \end{aligned} \]

が言えるので、\(n\) としてあり得るのは高々\(60 \times 38 \times 26 = 59280\) 通りであることが証明できます。よって、\(X \leq 59280\) が証明できたため、メモ化再帰による DP の計算は高速に動作することが言えました。

  • 実装例(C++)
#include <iostream>
#include <map>
using namespace std;
#include "atcoder/modint.hpp"
using mint = atcoder::modint998244353;

long long N;
map<long long, mint> memo;
mint dp(long long n) {
  if (n >= N) return n == N ? 1 : 0;
  if (memo.count(n)) return memo[n];
  mint res = 0;
  for (int i = 2; i <= 6; i++) res += dp(i * n);
  return memo[n] = res / 5;
}
int main() {
  cin >> N;
  cout << dp(1).val() << "\n";
}

posted:
last update: