L - Long Sequence Inversion 2 Editorial
by
shinchan
概要
\(P\) の値の降順に桁を見ていくことにします。\(P\) の値の降順に桁を見ていくとき、最初に見る桁は、その桁でほとんど大小が決まります(その桁の値が異なるとき、他の値によって大小が変化しません)。
以降、「\(i\) 桁目」と言うときは 「 \(i\) 番目に見た桁」ではなく「下から \(i\) 桁目」の意味で使っています。
\(i\) 桁目の値が違う場合はその時点で答えを求め、同じ場合、さらに別の桁について求めていくという方針をとります。
0-indexedで下から \(i\) 桁目を見るとき、まず値を以下のように定義します。
\(x :=\) \(i\) 桁目より小さい桁でまだ選ばれてないものの個数
\(y :=\) \(i\) 桁目より大きい桁でまだ選ばれてないものの個数
\(z :=\) ( \(i\) を除いて)既に選ばれている桁の個数
簡単のため、\(z = 0\) について説明します(つまり \(P_i\) が最大となる桁 \(i\) です)。
\(i+1\) 桁目以降が同じところについて、\(i\) 桁目が \(0\) の大きさ \(B^x\) の連続した固まり、\(i\) 桁目が \(1\) の大きさ \(B^x\) の連続した固まり、…、\(i\) 桁目が \(B-1\) の大きさ \(B^x\) の連続した固まりが並んでいるイメージです。\(i+1\) 桁目以降を考えると、この大きさ \(B^x\) の固まりの \(B\) 個の固まりが、 \(B^y\) 個並んでいます。
\(i + 1\) 桁目以降が同じ \(2\) 項の転倒数
比較する \(2\) 項の \(i + 1\) 桁目以降が同じ場合については、\(i\) 桁目の添字の大小と \(V_i\) の大小の両方を比較する必要があります。
\(i+1\) 桁目以降を固定する場合の数は \(B^y\) であり、\(i\) 桁目よりも下の桁を考慮すると、さきほどの大きさ \(B^x\) のブロックから \(1\) つずつ取り出すことになり、 \(B^x \times B^x\) をかけます。そして、既に固定されている桁数を考慮して \(B^z\) をかけます。
よって、\(i+1\) 桁目以降が同じで \(i\) 桁目が異なる \(2\) 項についての転倒数の総和は、 \(V_i\) の転倒数 ( \(T_i\) とする)を使って、 \(T_i B^y (B^x)^2 B^z\) と書けます。
\(i + 1\) 桁目以降が異なる \(2\) 項の転倒数
比較する \(2\) 項の \(i + 1\) 桁目以降が異なる場合については、添字の大小をあまり気にする必要がありません。( \(i+1\) 桁目以降が異なる場合、\(i\) 桁目とそれより下の桁の大小によって、添字の大小が変化しません)。
イメージして頂きたいのが、\(1\) から \(N\) の順列を \(2\) つ連結させたとして、前半 \(N\) 項から \(1\) つ、後半 \(N\) 項から \(1\) つ選ぶようにして計算した転倒数(つまり、前半同士、後半同士は無視する)は、\(\frac{N(N-1)}{2}\) です。これは、前半の要素 \(1\) よりも小さい要素は後半に \(0\) 個存在、前半の要素 \(2\) よりも小さい要素は後半に \(1\) 個存在…のように考えるとわかります。
今回の場合、\(i+1\) 桁目以降を固定し、\(i\) 桁目を動かしたときの転倒数への寄与は、\( \frac{B(B-1)}{2}\) となります。
\(i+1\) 桁目以降を選ぶ場合の数は \(_{B^y} C_2\) です。\(i + 1\) 桁目以降が同じ場合と同様に、大きさ \(B^x\) のブロックから \(1\) つずつ取り出すことになり、 \(B^x \times B^x\) をかけます。そして、既に固定されている桁数を考慮して \(B^z\) をかけます。
よって、\(i+1\) 桁目以降が異なり、 \(i\) 桁目が異なる \(2\) 項についての転倒数の総和は、 \(_{B^y} C_2 (B^x)^2 \frac{B(B-1)}{2} B^z\) と書けます。
まとめ
上記 \(2\) つを、桁を \(P\) の降順に見て足し合わせていけば答えが求まります。
下から \(i\) 桁目について、\(T_i B^y (B^x)^2 B^z + _{B^y} C_2 (B^x)^2 \frac{B(B-1)}{2} B^z\) が求めるものなので、変形すると、以下のようになります(\(T_i\) は \(V_i\) の転倒数)。
\[\left(T_i + \frac{B(B-1) (B^y - 1)}{4} \right) B^{y + 2x + z}\]
実装例
C++, 124ms https://atcoder.jp/contests/ttpc2024_1/submissions/60810775
#include <bits/stdc++.h>
using namespace std;
#define all(v) (v).begin(),(v).end()
#define pb(a) push_back(a)
#define rep(i, n) for(int i=0;i<n;i++)
#define foa(e, v) for(auto&& e : v)
using ll = long long;
const ll mod = 998244353;
template<int MOD> struct Modint {
long long val;
constexpr Modint(long long v = 0) noexcept : val(v % MOD) { if (val < 0) val += MOD; }
constexpr int mod() const { return MOD; }
constexpr long long value() const { return val; }
constexpr Modint operator - () const noexcept { return val ? MOD - val : 0; }
constexpr Modint operator + (const Modint& r) const noexcept { return Modint(*this) += r; }
constexpr Modint operator - (const Modint& r) const noexcept { return Modint(*this) -= r; }
constexpr Modint operator * (const Modint& r) const noexcept { return Modint(*this) *= r; }
constexpr Modint operator / (const Modint& r) const noexcept { return Modint(*this) /= r; }
constexpr Modint& operator += (const Modint& r) noexcept {
val += r.val;
if (val >= MOD) val -= MOD;
return *this;
}
constexpr Modint& operator -= (const Modint& r) noexcept {
val -= r.val;
if (val < 0) val += MOD;
return *this;
}
constexpr Modint& operator *= (const Modint& r) noexcept {
val = val * r.val % MOD;
return *this;
}
constexpr Modint& operator /= (const Modint& r) noexcept {
long long a = r.val, b = MOD, u = 1, v = 0;
while (b) {
long long t = a / b;
a -= t * b, swap(a, b);
u -= t * v, swap(u, v);
}
val = val * u % MOD;
if (val < 0) val += MOD;
return *this;
}
constexpr bool operator == (const Modint& r) const noexcept { return this->val == r.val; }
constexpr bool operator != (const Modint& r) const noexcept { return this->val != r.val; }
friend constexpr istream& operator >> (istream& is, Modint<MOD>& x) noexcept {
is >> x.val;
x.val %= MOD;
if (x.val < 0) x.val += MOD;
return is;
}
friend constexpr ostream& operator << (ostream& os, const Modint<MOD>& x) noexcept {
return os << x.val;
}
constexpr Modint<MOD> pow(long long n) noexcept {
if (n == 0) return 1;
if (n < 0) return this->pow(-n).inv();
Modint<MOD> ret = pow(n >> 1);
ret *= ret;
if (n & 1) ret *= *this;
return ret;
}
constexpr Modint<MOD> inv() const noexcept {
long long a = this->val, b = MOD, u = 1, v = 0;
while (b) {
long long t = a / b;
a -= t * b, swap(a, b);
u -= t * v, swap(u, v);
}
return Modint<MOD>(u);
}
};
const int MOD = mod;
using mint = Modint<MOD>;
template<class T> struct BIT {
int n;
vector<T> a;
BIT(int m) :n(m), a(m + 1, 0) {}
void add(int x, T y) {
x ++;
while(x <= n) {
a[x] += y;
x += x & -x;
}
}
T sum(int x) {
T r = 0;
while(x > 0) {
r += a[x];
x -= x & -x;
}
return r;
}
T sum(int x, int y) {
return sum(y) - sum(x);
}
};
int main() {
ll L, B;
cin >> L >> B;
vector<ll> ord(L, 0);
rep(i, L) {
ll x; cin >> x;
ord[x] = i;
}
reverse(all(ord)); // 降順
vector v(L, vector(B, 0LL));
rep(i, L) rep(j, B) cin >> v[i][j];
BIT<ll> notused(L); // まだ選ばれてない桁が1
rep(i, L) notused.add(i, 1);
auto tentou = [&](vector<ll> vec) -> ll {
ll sum = 0;
ll n = vec.size();
BIT<ll> bit(n);
rep(i, n) {
sum += bit.sum(vec[i] + 1, n);
bit.add(vec[i], 1);
}
return sum;
};
ll used = 0;
mint ans = 0;
for(int i : ord) {
ll x = notused.sum(0, i);
ll y = notused.sum(i + 1, L);
ll t = tentou(v[i]);
ans += mint(B).pow(used + x * 2 + y) * (mint(t) + (mint(B).pow(y) - mint(1)) * mint(B - 1) * mint(B) / mint(4));
notused.add(i, -1);
used ++;
}
cout<< ans << endl;
}
posted:
last update: