
E - Complete Binary Tree Editorial by en_translator


We regard the tree as a rooted tree rooted at \(1\).

The tree given in this problem is a complete binary tree: every level is full except possibly the last, which is filled from the left. Such trees have the following properties, so they often appear in problems and are used to implement data structures such as segment trees.

  • The height of the tree is \(\lfloor \log_2 N \rfloor\).
  • The indices of vertex \(i\)’s parent, left child, and right child can simply be represented as \(\lfloor \frac{i}{2} \rfloor,2i, 2i+1\), respectively.

The latter property further yields the following fact:

  • For every integer \(d \ge 0\), the set of indices of the descendants of vertex \(i\) at distance exactly \(d\) from vertex \(i\) is the intersection of the segments \([i\times 2^d, (i+1)\times 2^d)\) and \([1,N]\). (*)

Back to the original problem. Let \(Y\) be a vertex whose distance from vertex \(X\) is \(K\), and let \(Z\) be the lowest common ancestor of \(X\) and \(Y\). Since \(Z\) is either \(X\) itself or one of its ancestors, and the height of the tree is \(\lfloor \log_2 N \rfloor\), there are at most \(\lfloor \log_2 N \rfloor + 1\) candidates for \(Z\). Moreover, for a fixed \(Z\), fact (*) lets us count in \(O(1)\) time the number of vertices \(Y\) whose lowest common ancestor with \(X\) is \(Z\) and whose distance from \(X\) is \(K\). Therefore, each query can be answered in \(O(\log N)\) time.

For example, assume that \(X=8\) and \(K=5\), and fix \(Z=2\). Since the distance between vertices \(X\) and \(Z\) is \(2\), and vertex \(4\) is the child of \(2\) on the path to \(8\), “the number of vertices whose lowest common ancestor with vertex \(8\) is \(2\) and whose distance from vertex \(8\) is \(5\)” equals “the number of vertices that are descendants of vertex \(2\) but not of vertex \(4\) and whose distance from vertex \(2\) is \(3\),” to which fact (*) now applies. The rest is implementation detail: we can either explicitly find the segments that satisfy the conditions, or subtract “the number of descendants of vertex \(4\) at distance \(2\) from it” from “the number of descendants of vertex \(2\) at distance \(3\) from it.” The sample code below adopts the former approach.

Sample code (C++):

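#include<bits/stdc++.h>

using namespace std;

using ll = long long;

// Depth of vertex n (the root 1 has depth 0), i.e. floor(log2 n).
int depth(ll n) {
    int d = 0;
    while (true) {
        n >>= 1;
        if (!n) break;
        ++d;
    }
    return d;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t;
    cin >> t;
    while (t--) {
        ll n, x, k;
        cin >> n >> x >> k;
        int dn = depth(n);  // height of the tree
        int dx = depth(x);
        ll ans = 0;
        for (int i = 0; i <= dx; i++) { // i = depth(lca(x, y))
            // z = lca(x, y) is dx - i steps above x, which must not exceed k
            if (dx - i > k) continue;
            // y has depth i + (k - (dx - i)); no vertex is deeper than n
            if (i + k - (dx - i) > dn) continue;
            // the valid y form the half-open index segment [l, r)
            ll l, r;
            if (i == dx) {
                // z = x: y is a descendant of x at distance k
                l = x << k;
                r = (x + 1) << k;
            } else {
                ll z = x >> (dx - i);
                if (dx - i < k) {
                    // y lies under the child of z that is NOT on the path to x;
                    // bit (dx - i - 1) of x tells which child the path uses
                    l = z * 2 + ((~x >> (dx - i - 1)) & 1);
                    l <<= (k - (dx - i) - 1);
                    r = l + (1LL << (k - (dx - i) - 1));
                } else {
                    // dx - i == k: y is z itself
                    l = z, r = z + 1;
                }
            }
            // clip the segment to [1, n]
            if (l > n) continue;
            r = min(r, n + 1);
            ans += r - l;
        }
        cout << ans << '\n';
    }
}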
