Ex - Perfect Binary Tree Official Editorial by en_translator


Here is an important fact:

  • No induced subgraph that is a perfect binary tree rooted at Vertex \(1\) has depth \(20\) or more.
    • Reason: a perfect binary tree of depth \(d\) has \(2^{d+1}-1\) vertices, so one of depth \(20\) or more would need at least \(2^{21}-1 = 2097151\) vertices. By the constraints, the original graph has at most \(3 \times 10^5\) vertices, which is not enough.
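As a quick check of this bound, here is a minimal sketch (the helper name `maxPerfectDepth` is hypothetical) that computes the largest depth \(d\) for which a perfect binary tree with \(2^{d+1}-1\) vertices fits into \(n\) vertices. For \(n = 3 \times 10^5\) it returns \(17\), so the cutoff \(20\) leaves some slack.

```cpp
#include <cassert>

// Hypothetical helper: largest d such that a perfect binary tree of depth d,
// which has 2^(d+1) - 1 vertices, fits into n vertices (-1 if n < 1).
int maxPerfectDepth(long long n) {
    assert(n >= 0 && n < (1LL << 60));  // keep the shifts below in range
    int d = -1;
    // Increase d while a tree of depth d + 1 (2^(d+2) - 1 vertices) still fits.
    while ((1LL << (d + 2)) - 1 <= n) ++d;
    return d;
}
```

For example, `maxPerfectDepth(7)` is \(2\) (the full tree with \(7\) vertices) and `maxPerfectDepth(6)` is \(1\).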

That’s why we can say:

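  • We may ignore every vertex whose depth (distance from Vertex \(1\), which has depth \(0\)) is \(20\) or more, because such a vertex can never belong to a valid tree.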

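Thus the following approach is fast enough: whenever one of the \(N\) vertices is added, update some values on the vertex itself and on each of its ancestors. Since only vertices of depth less than \(20\) matter, this is at most \(20\) updates per vertex, or at most \(20N\) updates in total.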
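The cost argument can be sketched as follows, assuming a 1-indexed parent array `p` with `p[1] = 0` and `p[i] < i` (the helper name `countAncestorUpdates` is hypothetical). It counts how many per-vertex updates such an ancestor walk performs, skipping vertices of depth \(20\) or more, so each added vertex contributes at most \(20\).

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: counts the (vertex, ancestor) updates made when each
// vertex is added and the path up to Vertex 1 is walked.
// Assumes p[1] == 0 (virtual parent) and 1 <= p[i] < i for i >= 2.
long long countAncestorUpdates(const std::vector<int>& p, int maxDepth = 20) {
    const int n = static_cast<int>(p.size()) - 1;
    assert(n >= 1 && p[1] == 0);
    std::vector<int> depth(n + 1, 0);
    long long updates = 0;
    for (int i = 1; i <= n; ++i) {
        if (i >= 2) {
            assert(1 <= p[i] && p[i] < i);
            depth[i] = depth[p[i]] + 1;
        }
        if (depth[i] >= maxDepth) continue;  // never part of a valid tree
        // Walk i -> p[i] -> ... -> 1; one update per vertex on the path,
        // i.e. depth[i] + 1 <= maxDepth updates.
        for (int v = i; v != 0; v = p[v]) ++updates;
    }
    return updates;
}
```

On a path of \(25\) vertices only Vertices \(1\) through \(20\) are walked, giving \(1 + 2 + \dots + 20 = 210\) updates.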

Now let us solve the original problem. We manage the following values:

  • \(dp[v][d]\) … the number of vertex sets, among the vertices added so far, that induce a perfect binary tree of depth \(d\) (with \(2^{d+1}-1\) vertices) rooted at Vertex \(v\)
  • \(csum[v][d]\) … the sum of \(dp[w][d]\) over all children \(w\) of \(v\)
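To make these two definitions concrete, here is a minimal sketch (the names `Tables`, `buildTables`, and `MAXD` are hypothetical) that recomputes \(dp\) and \(csum\) for a fixed tree from scratch, as an offline cross-check rather than the editorial's incremental update. It treats a virtual Vertex \(0\) as the parent of Vertex \(1\), assumes \(p_i < i\) so that visiting vertices in decreasing order handles children before parents, and uses the recurrence \(dp[v][0] = 1\), \(dp[v][d+1] = \sum dp[a][d]\,dp[b][d]\) over unordered pairs \(\{a, b\}\) of distinct children of \(v\).

```cpp
#include <cassert>
#include <vector>

constexpr long long MOD = 998244353;
constexpr int MAXD = 20;  // depths 0..19, as in the editorial

using Table = std::vector<std::vector<long long>>;
struct Tables { Table dp, csum; };

Table makeTable(int n) { return Table(n + 1, std::vector<long long>(MAXD, 0)); }

// Hypothetical offline version: p[1] = 0 (virtual parent), p[i] < i.
Tables buildTables(const std::vector<int>& p) {
    const int n = static_cast<int>(p.size()) - 1;
    assert(n >= 1 && p[1] == 0);
    Tables t{makeTable(n), makeTable(n)};
    // pairs[v][d]: sum over unordered child pairs {a, b} of dp[a][d] * dp[b][d]
    Table pairs = makeTable(n);
    for (int v = n; v >= 1; --v) {  // all children of v are already processed
        t.dp[v][0] = 1;             // the single vertex v
        for (int d = 0; d + 1 < MAXD; ++d) t.dp[v][d + 1] = pairs[v][d];
        const int par = p[v];
        for (int d = 0; d < MAXD; ++d) {
            // Pair v with each earlier-processed sibling, then add v to the sum.
            pairs[par][d] = (pairs[par][d] + t.csum[par][d] * t.dp[v][d]) % MOD;
            t.csum[par][d] = (t.csum[par][d] + t.dp[v][d]) % MOD;
        }
    }
    return t;
}
```

On the complete binary tree with \(7\) vertices (parents \(1, 1, 2, 2, 3, 3\) for Vertices \(2, \dots, 7\)) this gives \(dp[1][0] = dp[1][1] = dp[1][2] = 1\), and on a star where Vertices \(2, 3, 4\) are children of Vertex \(1\) it gives \(dp[1][1] = 3\).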

Here, it is convenient to treat Vertex \(1\) as having a virtual Vertex \(0\) as its parent.

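We add the vertices one at a time, starting from Vertex \(1\), then Vertex \(2\), and so on, and output the answer after each addition.
Since Vertex \(1\) is the only child of Vertex \(0\), we have \(csum[0][d] = dp[1][d]\), so the answer after each step is \(\sum_{d=0}^{19} csum[0][d]\).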
When a Vertex \(v\) is added, do the following:

  • If the depth of \(v\) is \(20\) or larger, then there is nothing to update.
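  • Otherwise, walk from \(v\) up to Vertex \(1\), updating \(dp\) and \(csum\) at each vertex \(u\) on the path:
    • Vertices are added in increasing order of index and every parent has a smaller index than its children, so the new vertex \(v\) has no children yet and must be a leaf of any perfect binary tree containing it. Hence every such tree rooted at \(u\) has depth exactly \(d\), the distance from \(u\) to \(v\).
    • Let \(val\) be the number of new perfect binary trees of depth \(d\) rooted at \(u\) that contain \(v\); for \(u = v\) we have \(d = 0\) and \(val = 1\). Add \(val\) to both \(dp[u][d]\) and \(csum[p_u][d]\), where \(p_u\) is the parent of \(u\).
    • A new tree of depth \(d+1\) rooted at \(p_u\) consists of \(p_u\), one of the \(val\) new trees rooted at \(u\), and one existing tree of depth \(d\) rooted at a sibling of \(u\). There are \(csum[p_u][d] - dp[u][d]\) of the latter (using the values before this step's update), so the value carried to \(p_u\) is \(val \times (csum[p_u][d] - dp[u][d])\).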
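As a small worked example of these updates (our own illustration), take the tree in which Vertices \(2\) and \(3\) are both children of Vertex \(1\). Write \(u\) for the vertex currently being updated, \(d\) for its distance from the added vertex, and \(val\) for the count carried upward. After Vertices \(1\) and \(2\) are added, adding Vertex \(3\) proceeds as follows:

\[
\begin{aligned}
&\text{Before: } csum[0][0]=1,\ csum[0][1]=0,\ csum[1][0]=1,\ dp[3][0]=0,\ dp[1][1]=0.\\
&u=3,\ d=0,\ val=1:\quad \text{next } val = 1\cdot(csum[1][0]-dp[3][0]) = 1\cdot(1-0) = 1;\quad dp[3][0]=1,\ csum[1][0]=2.\\
&u=1,\ d=1,\ val=1:\quad dp[1][1]=1,\ csum[0][1]=1;\ \text{the parent is } 0\text{, so stop.}\\
&\text{Answer: } csum[0][0]+csum[0][1] = 1+1 = 2 \quad (\text{the sets } \{1\} \text{ and } \{1,2,3\}).
\end{aligned}
\]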

Sample code (C++):

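#include <bits/stdc++.h>
using namespace std;

constexpr long long MOD = 998244353;
constexpr int MAXD = 20;  // valid trees rooted at Vertex 1 have depth < 20

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n;
  cin >> n;
  // p[i] is the parent of Vertex i (p[i] < i); Vertex 0 is the virtual parent of Vertex 1.
  vector<int> p(n + 1, 0);
  for (int i = 2; i <= n; i++) cin >> p[i];

  vector<int> d(n + 1, 0);  // d[i]: depth of Vertex i, with d[1] = 0
  // dp[v][k]: number of perfect binary trees of depth k rooted at v
  // csum[v][k]: sum of dp[w][k] over all children w of v
  vector<vector<long long>> dp(n + 1, vector<long long>(MAXD, 0));
  vector<vector<long long>> csum(n + 1, vector<long long>(MAXD, 0));

  // Only Vertex 1 exists: {1} is a perfect binary tree of depth 0.
  dp[1][0] = 1;
  csum[0][0] = 1;
  cout << "1\n";

  for (int i = 2; i <= n; i++) {
    d[i] = d[p[i]] + 1;

    if (d[i] < MAXD) {
      // val: number of new depth-curd trees rooted at v that contain Vertex i
      long long val = 1;
      int v = i, curd = 0;

      while (true) {
        int par = p[v];
        // Pair each new tree at v with an existing depth-curd tree at a sibling
        // of v; csum[par][curd] - dp[v][curd] is read before the update below.
        long long nval = val * ((csum[par][curd] - dp[v][curd] + MOD) % MOD) % MOD;

        csum[par][curd] = (csum[par][curd] + val) % MOD;
        dp[v][curd] = (dp[v][curd] + val) % MOD;
        if (par == 0) break;  // Vertex 1 has been updated
        v = par;
        curd++;
        val = nval;
      }
    }

    // csum[0][k] == dp[1][k]: sum the trees rooted at Vertex 1 over all depths.
    long long res = 0;
    for (int k = 0; k < MAXD; k++) res = (res + csum[0][k]) % MOD;
    cout << res << "\n";
  }
  return 0;
}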
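For testing on small inputs, the sample code's output can be compared against a brute force. The sketch below (the name `bruteForce` is hypothetical) assumes, as the sample code does, that a counted set must contain Vertex \(1\) and induce a perfect binary tree rooted there. It enumerates all \(2^{k-1}\) such candidate sets for each \(k\), so it is only practical for \(n \le 20\).

```cpp
#include <cassert>
#include <vector>

// Hypothetical brute force: answers[k-1] is the number of subsets of {1..k}
// containing Vertex 1 that induce a perfect binary tree rooted at Vertex 1.
// Assumes p[1] == 0 and p[i] < i. Exponential; for testing only.
std::vector<long long> bruteForce(const std::vector<int>& p) {
    const int n = static_cast<int>(p.size()) - 1;
    assert(n >= 1 && n <= 20);
    std::vector<long long> answers;
    for (int k = 1; k <= n; ++k) {
        long long count = 0;
        for (int mask = 0; mask < (1 << (k - 1)); ++mask) {
            // Bit (i - 2) of mask decides whether Vertex i (2 <= i <= k) is chosen.
            std::vector<bool> chosen(k + 1, false);
            chosen[1] = true;
            for (int i = 2; i <= k; ++i) chosen[i] = (mask >> (i - 2)) & 1;
            // Connected iff every chosen vertex other than 1 has a chosen parent.
            bool ok = true;
            std::vector<int> childCount(k + 1, 0), depth(k + 1, 0);
            for (int i = 2; i <= k && ok; ++i) {
                if (!chosen[i]) continue;
                if (!chosen[p[i]]) { ok = false; break; }
                ++childCount[p[i]];
                depth[i] = depth[p[i]] + 1;
            }
            // Perfect: every vertex has 0 or 2 children, all leaves at one depth.
            int leafDepth = -1;
            for (int i = 1; i <= k && ok; ++i) {
                if (!chosen[i]) continue;
                if (childCount[i] == 0) {
                    if (leafDepth == -1) leafDepth = depth[i];
                    else if (leafDepth != depth[i]) ok = false;
                } else if (childCount[i] != 2) {
                    ok = false;
                }
            }
            if (ok) ++count;
        }
        answers.push_back(count);
    }
    return answers;
}
```

On the tree where Vertices \(2, 3, 4\) are all children of Vertex \(1\), it returns \((1, 1, 2, 4)\): for \(k = 4\) the valid sets are \(\{1\}\), \(\{1,2,3\}\), \(\{1,2,4\}\), and \(\{1,3,4\}\).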
