公式

F - Beautiful Kadomatsu 解説 by physics0523

解説(後半)

解説の前半を先にお読みください。

実は、各要素が相異なる数列 \(a=(a_1,a_2,\dots,a_k)\)門松的 であることは、以下が成り立つことと同値です。

  • \(a_1 < a_2\) かつ \(a_{k-1}>a_k\)

証明

  • \(a_{i-1} < a_i\) かつ \(a_i > a_{i+1}\) なる部分を \(x\) 折り返し
  • \(a_{i-1} > a_i\) かつ \(a_i < a_{i+1}\) なる部分を \(y\) 折り返し

と呼ぶことにします。

ある \(k\) について \(a_k < a_{k+1}\) である場合、次に折り返す場合は必ず \(x\) 折り返しになります。
また、ある \(k\) について \(a_k > a_{k+1}\) である場合、次に折り返す場合は必ず \(y\) 折り返しになります。

\(a_1 < a_2\) である場合、最初に \(x\) 折り返し、次に \(y\) 折り返し、次に \(x\) 折り返し、… と続くことになります。 \(x>y\) となるためには、最後に \(x\) 折り返しを受けた状態で \(y\) 折り返しを受けていない状態、つまり \(a_{k-1} > a_k\) であることが必要であり、これで十分です。
\(a_1 > a_2\) である場合、最初に \(y\) 折り返し、次に \(x\) 折り返し、次に \(y\) 折り返し、… と続くことになります。この場合、 \(x>y\) となることはありません。

以上より、冒頭の事実の証明が完了しました。

あとは、門松的である部分列を数え上げればよいです。

\(a_2,a_{k-1}\) を固定して考えます。
\(a_1 < a_2\) なる \(a_1\) の取り方が \(p\) 通り、 \(a_{k-1} > a_k\) なる \(a_k\) の取り方が \(q\) 通りあるとします。
\(k=3\) ( 即ち \(k-1=2\) ) のケースを考慮するために、 \(pq\) を答えに加算します。
\(k \ge 4\) のケースについて、 \(a_2\) が元の列での \(P_l\)\(a_{k-1}\) が元の列での \(P_r\) であるなら、 \(P_{l+1},P_{l+2},\dots,P_{r-1}\) は使うかどうかを自由に選択できます。なので、答えに \(pq \times 2^{r-l-1}\) を加算します。
\(a_1 < a_2\) なる \(a_1\) の取り方、 \(a_{k-1} > a_k\) なる \(a_k\) の取り方は segtree 、平衡二分探索木などのデータ構造を活用して数え上げることができます。
\(2^{r-l-1}\) が付いた加算たちも、累積和の要領で足し合わせることができます。

この解法の時間計算量は \(O(N \log N)\) です。

実装例 (C++):

#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include<ext/pb_ds/tag_and_trait.hpp>
using namespace __gnu_pbds;

using namespace std;
using ll=long long;

const ll mod=998244353;

int main(){
  ll n;
  cin >> n;
  vector<ll> p(n);
  for(auto &nx : p){cin >> nx;}

  vector<ll> beg(n),fin(n);
  tree<ll,null_type,less<ll>,rb_tree_tag,tree_order_statistics_node_update> tr;
  for(ll i=0;i<n;i++){
    beg[i]=tr.order_of_key(p[i]);
    tr.insert(p[i]);
  }
  tr.clear();
  for(ll i=n-1;i>=0;i--){
    fin[i]=tr.order_of_key(p[i]);
    tr.insert(p[i]);
  }

  ll res=0;
  for(ll i=0;i<n;i++){
    res+=beg[i]*fin[i];
    res%=mod;
  }
  ll s=0;
  for(ll i=0;i<n;i++){
    res+=s*fin[i];
    res%=mod;
    s=(2*s+beg[i])%mod;
  }
  cout << res << "\n";
  return 0;
}

投稿日時:
最終更新: