公式

D - Kadomatsu Subsequence 解説 by physics0523


原案: admin

いったん次の条件だけを取り出して処理しましょう。

  • \(1 \le i,j,k \le N\)
  • \(A_i : A_j : A_k = 7:5:3\)

\(A\) の各要素が正整数であることから、二番目の条件は次のように言い換えられます。

  • ある正整数 \(t\) を用いて、 \(A_i=7t,A_j=5t,A_k=3t\) と書ける。

また、 \(A_i,A_j,A_k\) が相異なることも分かります。

この事実を使うと、次のような解法が出せます。

  • \(A\) の各要素 \(A_l\) に対して、次の事実が言える。
    • \(A_l\)\(7\) の倍数である時、 \(t=A_l/7\) の場合で \(i=l\) とする権利がある。
    • \(A_l\)\(5\) の倍数である時、 \(t=A_l/5\) の場合で \(j=l\) とする権利がある。
    • \(A_l\)\(3\) の倍数である時、 \(t=A_l/3\) の場合で \(k=l\) とする権利がある。
    • 正当な \((i,j,k)\) の組に対して \(A_i,A_j,A_k\) は相異なることから、同一の \(t\) に対して \(i\) の候補の数、 \(j\) の候補の数、 \(k\) の候補の数をそのまま掛け合わせれば数え上げが完了する。

これで、以下の条件がない場合の数え上げは完了しました。

  • \(\min(i,j,k)=j\) または \(\max(i,j,k)=j\)

この条件をどのように課せばよいでしょうか?
扱いやすいよう、先ほどの数え上げをした上で以下に該当するものを引き去ることにします。

  • \(i,j,k\) を値の昇順に並び替えた場合、 \(j\)\(2\) 番目に来る。

\(t\) に対して次のようにすると、このケースも数え上げることができます。

  • 全ての \(j\) の候補について、その \(j\) の候補より小さい \(i\) の個数とその \(j\) の候補より大きい \(k\) の個数とを掛け合わせることで、 \(i<j<k\) のケースを数え上げることができる。
  • 全ての \(j\) の候補について、その \(j\) の候補より小さい \(k\) の個数とその \(j\) の候補より大きい \(i\) の個数とを掛け合わせることで、 \(k<j<i\) のケースを数え上げることができる。

これは、イベントソートの要領で実装可能です。

答えが非常に大きくなり、 \(32\)bit 整数では収まらないケースがあることに注意してください。答えは粗く見積もっても \((3 \times 10^5)^3\) 以下なので、 \(64\)bit 符号付き整数には収まります。

この解法の時間計算量は \(O(N \log N)\) です。

実装例 (C++):

#include<bits/stdc++.h>

using namespace std;
using ll=long long;
using pl=pair<ll,ll>;

ll solve(vector<ll> &v7,vector<ll> &v5,vector<ll> &v3){
  ll res=ssize(v7);
  res*=ssize(v5);
  res*=ssize(v3);
  if(res==0){return 0;}

  vector<pl> vp;
  for(auto &nx : v7){vp.push_back({nx,7});}
  for(auto &nx : v5){vp.push_back({nx,5});}
  for(auto &nx : v3){vp.push_back({nx,3});}
  sort(vp.begin(),vp.end());
  ll p7=0,s7=v7.size();
  ll p3=0,s3=v3.size();
  for(auto &nx : vp){
    if(nx.second==7){
      p7++; s7--;
    }
    else if(nx.second==5){
      res-=p7*s3;
      res-=p3*s7;
    }
    else if(nx.second==3){
      p3++; s3--;
    }
  }
  return res;
}

int main(){
  ll n;
  cin >> n;
  map<pl,vector<ll>> mp;
  set<ll> tst;
  for(ll i=0;i<n;i++){
    ll a;
    cin >> a;
    if(a%7==0){
      mp[{a/7,7}].push_back(i);
      tst.insert(a/7);
    }
    if(a%5==0){mp[{a/5,5}].push_back(i);}
    if(a%3==0){mp[{a/3,3}].push_back(i);}
  }
  ll res=0;
  for(auto &nx : tst){
    res+=solve(mp[{nx,7}],mp[{nx,5}],mp[{nx,3}]);
  }
  cout << res << "\n";
  return 0;
}

投稿日時:
最終更新: