Official

E - E [max] Editorial by yuto1115

解説

\(N\) 個のサイコロのいずれかの面に書かれている数を小さい順に全て並べた列を \(S=(S_1,S_2,\dots,S_k)\) とおきます。また、便宜上 \(S_0=0\) とおきます。

サイコロ \(i\) の上を向いた面に書かれた数を表す確率変数を \(X_i\) とおくと、求める期待値 \(E\) は以下のように表されます。(\(P[*]\)\(*\) が起こる確率を表します。)

\[E=\displaystyle \sum_{i=1}^{k} S_{i}\times P[\max_{j=1..N}X_j=S_i]\]

競技プログラミングにおける確率・期待値の問題では、「何かの最大値が特定の値に一致する確率を求めるより、最大値が特定の値以下になる確率を求める方が往々にして簡単である」ということを利用して式変形を行うことがしばしばあります。本問題でもこのテクニックを用いることができ、

\[ \displaystyle \begin{aligned} E &=\sum_{i=1}^{k} S_{i}\times P[\max_{j}X_j=S_i] \\ &= \sum_{i=1}^{k} S_{i}\times \left(P[\max_{j}X_j\leq S_i]-P[\max_{j}X_j\leq S_{i-1}]\right) \\ &= S_k -\sum_{i=1}^{k-1} (S_{i+1}-S_i)\times P[\max_{j}X_j\leq S_i]\\ &= S_k -\sum_{i=1}^{k-1} (S_{i+1}-S_i)\prod_{j=1}^{N}\frac{B_i^{(j)}}{6}\\ \end{aligned} \]

となります。ここで、\(B_i^{(j)}\) はサイコロ \(j\)\(6\) つの面のうち書かれた数が \(S_i\) 以下であるようなものの数を表します。

全ての \(i,j\) に対する \(B_i^{(j)}\) の値を配列として保持することはできません。しかし、\(i=1,2,\dots,k\) の順に見ていきながら、現在の \(i\) に対する \(B_{i}^{(j)}\ (1\leq j\leq N)\) の値および \(\displaystyle \prod_{j=1}^{N}B_i^{(j)}\) の値のみを保持することにすると、更新操作(\(B_i^{(j)}\neq B_{i+1}^{(j)}\) なる \(i,j\) の組に対し、\(B_i^{(j)} \leftarrow B_{i+1}^{(j)}\) とする操作)は \(O(N)\) 回しか行われないため、十分高速に処理することができます。

全体の計算量は、実装方針に応じて \(O(N(\log N + \log\mathrm{MOD}))\)\(O(N\log N + \log \mathrm{MOD})\) などになりますが、いずれにおいても十分高速です。

実装例 (C++) :

#include <bits/stdc++.h>
#include <atcoder/modint>

using namespace std;

using mint = atcoder::modint998244353;

int main() {
    int n;
    cin >> n;
    vector<vector<int>> a(n, vector<int>(6));
    vector<int> s;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < 6; j++) {
            cin >> a[i][j];
            s.push_back(a[i][j]);
        }
    }
    sort(s.begin(), s.end());
    s.erase(unique(s.begin(), s.end()), s.end());
    int k = s.size();
    vector<vector<int>> upd(k);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < 6; j++) {
            int id = lower_bound(s.begin(), s.end(), a[i][j]) - s.begin();
            upd[id].push_back(i);
        }
    }

    mint ans = 0;
    vector<int> b(n);
    mint prod = 1;
    int zero_cnt = n;
    for (int i = 0; i < k - 1; i++) {
        for (int j: upd[i]) {
            if (!b[j]) {
                --zero_cnt;
            } else {
                prod /= b[j];
            }
            ++b[j];
            prod *= b[j];
        }
        ans -= (zero_cnt ? 0 : prod) * (s[i + 1] - s[i]);
    }
    ans /= mint(6).pow(n);
    ans += s[k - 1];
    cout << ans.val() << endl;
}

posted:
last update: