Official

E - LEQ Editorial by penguinman


\(A\) の要素が互いに異なると仮定したとき、任意の整数 \(i,j\ (1 \leq i \lt j \leq N)\) について以下が成り立ちます。

  • \(A'_1=A_i\) かつ \(A'_k=A_j\) が成り立つような \(A\) の連続するとは限らない部分列 \(A'=(A'_1,A'_2,\ldots,A'_k)\) は、\(2^{j-i-1}\) 個存在する。

故に、\(A\) の要素が互いに異なると仮定した場合答えは以下の通りになります。

  • \(1 \leq i \lt j \leq N\) かつ \(A_i \leq A_j\) を満たすすべての整数対 \((i,j)\) に対する、\(2^{j-i-1}\) の総和

\(2^{j-i-1}=\frac{2^{j-1}}{2^i}\) より、以下のことが分かります。

  • \(j=1,2,\ldots,N\) について、\(B_j=(1 \leq i \lt j\) かつ \(A_i \leq A_j\) を満たすようなすべての整数 \(i\) に対する、\(\frac{1}{2^i}\) の総和\()\) と定義する。このとき、答えは \(\sum_{j=1}^{N} B_j \times 2^{j-1}\) である。

よって \(j=1,2,\ldots,N\) について \(B_j\) が求められてさえいれば、答えを求めることは容易です。

\(B_j\) を求めること自体はそこまで難しくはなく、\(A\) の要素を座標圧縮した上で Binary Indexed Tree 等を用いて添字の昇順に走査していくことで \(O(N \log N)\) で求値可能です。ここで \(\frac{1}{2^i}\) を小数等で管理するのは困難であるため、\(\text{mod}\ 998244353\) 上での逆元を用いて管理することが推奨されます。

\(A\) の要素が互いに異ならない場合においても、似たような手法により答えを求めることが可能です。よってこの問題を解くことができました。

計算量は実装により \(O(N \log N)\)\(O(N \log \text{mod})\) となります(後者は逆元の計算を毎回行った場合の計算量)。

実装例 (C++)

#include<bits/stdc++.h>
using namespace std;

using ll = long long;

const ll mod = 998244353;

struct binary_indexed_tree{
    int N;
    vector<ll> bit;
    binary_indexed_tree(int n):N(n){
        bit.resize(N+1,0);
    }
    ll addition(ll x, ll y){
        return (x+y)%mod;
    }
    void add(int x,ll a){
        x++;
        for(x; x<=N; x+=(x&-x)) bit[x] = addition(bit[x],a);
    }
    ll sum(int x){
        x++;
        ll ret=0;
        for(x; x>0; x-=(x&-x)) ret = addition(ret,bit[x]);
        return ret;
    }
};

ll modpow(ll x, ll y){
    ll ret = 1;
    while(0 < y){
        if(y & 1){
            ret *= x;
            ret %= mod;
        }
        x *= x;
        x %= mod;
        y >>= 1;
    }
    return ret;
}

int comp(vector<int> &A){
    std::map<int,int> mem;
    for(auto p: A) mem[p] = 0;
    int sz = 0;
    for(auto &p: mem) p.second = sz++;
    for(auto &p: A) p = mem[p];
    return sz;
}

int main(){
    const ll div = modpow(2,mod-2);
    int N; cin >> N;
    vector<int> A(N);
    for(int i=0; i<N; i++) cin >> A[i];
    int n = comp(A);
    binary_indexed_tree bit(n);
    ll ans = 0;
    for(int i=0; i<N; i++){
        ans += bit.sum(A[i])*modpow(2,i);
        ans %= mod;
        bit.add(A[i],modpow(div,i+1));
    }
    cout << ans << endl;
}

posted:
last update: