L - Long Sequence Inversion 2 Editorial by shinchan


概要

\(P\) の値の降順に桁を見ていくことにします。\(P\) の値の降順に桁を見ていくとき、最初に見る桁は、その桁でほとんど大小が決まります(その桁の値が異なるとき、他の値によって大小が変化しません)。

以降、「\(i\) 桁目」と言うときは 「 \(i\) 番目に見た桁」ではなく「下から \(i\) 桁目」の意味で使っています。

\(i\) 桁目の値が違う場合はその時点で答えを求め、同じ場合、さらに別の桁について求めていくという方針をとります。

0-indexedで下から \(i\) 桁目を見るとき、まず値を以下のように定義します。

  • \(x :=\) \(i\) 桁目より小さい桁でまだ選ばれてないものの個数

  • \(y :=\) \(i\) 桁目より大きい桁でまだ選ばれてないものの個数

  • \(z :=\) ( \(i\) を除いて)既に選ばれている桁の個数

簡単のため、\(z = 0\) について説明します(つまり \(P_i\) が最大となる桁 \(i\) です)。

\(i+1\) 桁目以降が同じところについて、\(i\) 桁目が \(0\) の大きさ \(B^x\) の連続した固まり、\(i\) 桁目が \(1\) の大きさ \(B^x\) の連続した固まり、…、\(i\) 桁目が \(B-1\) の大きさ \(B^x\) の連続した固まりが並んでいるイメージです。\(i+1\) 桁目以降を考えると、この大きさ \(B^x\) の固まりの \(B\) 個の固まりが、 \(B^y\) 個並んでいます。

\(i + 1\) 桁目以降が同じ \(2\) 項の転倒数

比較する \(2\) 項の \(i + 1\) 桁目以降が同じ場合については、\(i\) 桁目の添字の大小と \(V_i\) の大小の両方を比較する必要があります。

\(i+1\) 桁目以降を固定する場合の数は \(B^y\) であり、\(i\) 桁目よりも下の桁を考慮すると、さきほどの大きさ \(B^x\) のブロックから \(1\) つずつ取り出すことになり、 \(B^x \times B^x\) をかけます。そして、既に固定されている桁数を考慮して \(B^z\) をかけます。

よって、\(i+1\) 桁目以降が同じで \(i\) 桁目が異なる \(2\) 項についての転倒数の総和は、 \(V_i\) の転倒数 ( \(T_i\) とする)を使って、 \(T_i B^y (B^x)^2 B^z\) と書けます。

\(i + 1\) 桁目以降が異なる \(2\) 項の転倒数

比較する \(2\) 項の \(i + 1\) 桁目以降が異なる場合については、添字の大小をあまり気にする必要がありません。( \(i+1\) 桁目以降が異なる場合、\(i\) 桁目とそれより下の桁の大小によって、添字の大小が変化しません)。

イメージして頂きたいのが、\(1\) から \(N\) の順列を \(2\) つ連結させたとして、前半 \(N\) 項から \(1\) つ、後半 \(N\) 項から \(1\) つ選ぶようにして計算した転倒数(つまり、前半同士、後半同士は無視する)は、\(\frac{N(N-1)}{2}\) です。これは、前半の要素 \(1\) よりも小さい要素は後半に \(0\) 個存在、前半の要素 \(2\) よりも小さい要素は後半に \(1\) 個存在…のように考えるとわかります。

今回の場合、\(i+1\) 桁目以降を固定し、\(i\) 桁目を動かしたときの転倒数への寄与は、\( \frac{B(B-1)}{2}\) となります。

\(i+1\) 桁目以降を選ぶ場合の数は \(_{B^y} C_2\) です。\(i + 1\) 桁目以降が同じ場合と同様に、大きさ \(B^x\) のブロックから \(1\) つずつ取り出すことになり、 \(B^x \times B^x\) をかけます。そして、既に固定されている桁数を考慮して \(B^z\) をかけます。

よって、\(i+1\) 桁目以降が異なり、 \(i\) 桁目が異なる \(2\) 項についての転倒数の総和は、 \(_{B^y} C_2 (B^x)^2 \frac{B(B-1)}{2} B^z\) と書けます。

まとめ

上記 \(2\) つを、桁を \(P\) の降順に見て足し合わせていけば答えが求まります。

下から \(i\) 桁目について、\(T_i B^y (B^x)^2 B^z + _{B^y} C_2 (B^x)^2 \frac{B(B-1)}{2} B^z\) が求めるものなので、変形すると、以下のようになります(\(T_i\)\(V_i\) の転倒数)。

\[\left(T_i + \frac{B(B-1) (B^y - 1)}{4} \right) B^{y + 2x + z}\]

実装例

C++, 124ms https://atcoder.jp/contests/ttpc2024_1/submissions/60810775

#include <bits/stdc++.h>
using namespace std;
#define all(v) (v).begin(),(v).end()
#define pb(a) push_back(a)
#define rep(i, n) for(int i=0;i<n;i++)
#define foa(e, v) for(auto&& e : v)
using ll = long long;
const ll mod = 998244353;

template<int MOD> struct Modint {
    long long val;
    constexpr Modint(long long v = 0) noexcept : val(v % MOD) { if (val < 0) val += MOD; }
    constexpr int mod() const { return MOD; }
    constexpr long long value() const { return val; }
    constexpr Modint operator - () const noexcept { return val ? MOD - val : 0; }
    constexpr Modint operator + (const Modint& r) const noexcept { return Modint(*this) += r; }
    constexpr Modint operator - (const Modint& r) const noexcept { return Modint(*this) -= r; }
    constexpr Modint operator * (const Modint& r) const noexcept { return Modint(*this) *= r; }
    constexpr Modint operator / (const Modint& r) const noexcept { return Modint(*this) /= r; }
    constexpr Modint& operator += (const Modint& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Modint& operator -= (const Modint& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Modint& operator *= (const Modint& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Modint& operator /= (const Modint& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Modint& r) const noexcept { return this->val == r.val; }
    constexpr bool operator != (const Modint& r) const noexcept { return this->val != r.val; }
    friend constexpr istream& operator >> (istream& is, Modint<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Modint<MOD>& x) noexcept {
        return os << x.val;
    }
    constexpr Modint<MOD> pow(long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return this->pow(-n).inv();
        Modint<MOD> ret = pow(n >> 1);
        ret *= ret;
        if (n & 1) ret *= *this;
        return ret;
    }
    constexpr Modint<MOD> inv() const noexcept {
        long long a = this->val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Modint<MOD>(u);
    }
};

const int MOD = mod;
using mint = Modint<MOD>;

template<class T> struct BIT {
    int n;
    vector<T> a;
    BIT(int m) :n(m), a(m + 1, 0) {}
    void add(int x, T y) {
        x ++;
        while(x <= n) {
            a[x] += y;
            x += x & -x;
        }
    }
    T sum(int x) {
        T r = 0;
        while(x > 0) {
            r += a[x];
            x -= x & -x;
        }
        return r;
    }
    T sum(int x, int y) {
        return sum(y) - sum(x);
    }
};

int main() {
    ll L, B;
    cin >> L >> B;
    vector<ll> ord(L, 0);
    rep(i, L) {
        ll x; cin >> x;
        ord[x] = i;
    }
    reverse(all(ord)); // 降順

    vector v(L, vector(B, 0LL));
    rep(i, L) rep(j, B) cin >> v[i][j];
    BIT<ll> notused(L); // まだ選ばれてない桁が1
    rep(i, L) notused.add(i, 1);

    auto tentou = [&](vector<ll> vec) -> ll {
        ll sum = 0;
        ll n = vec.size();
        BIT<ll> bit(n);
        rep(i, n) {
            sum += bit.sum(vec[i] + 1, n);
            bit.add(vec[i], 1);
        }
        return sum;
    };

    ll used = 0;
    mint ans = 0;
    for(int i : ord) {
        ll x = notused.sum(0, i);
        ll y = notused.sum(i + 1, L);
        ll t = tentou(v[i]);
        ans += mint(B).pow(used + x * 2 + y) * (mint(t) + (mint(B).pow(y) - mint(1)) * mint(B - 1) * mint(B) / mint(4));
        notused.add(i, -1);
        used ++;
    }
    cout<< ans << endl;
}

posted:
last update: