Official

D - Bitmask Editorial by en_translator


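We write integers in binary and number their digits from the least significant end: the least significant digit is the \(0\)-th digit, the next one is the \(1\)-st digit, and so on (this matches the sample code). For non-negative integers \(n\) and \(i\), the \(i\)-th digit of \(n\) is denoted by \(n[i]\). Likewise, \(s[i]\) denotes the character of \(S\) at the \(i\)-th digit counted from its last character; positions beyond the length of \(S\) are treated as 0. Moreover, we identify the characters 0 and 1 with the digits \(0\) and \(1\).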

Take a non-negative integer \(t\) with \(t \leq N\), and let \(i\) be the position of the first digit at which \(t\) and \(N\) differ when scanning from the most significant digit to the least significant one. (If \(t=N\), let \(i=-1\).) In other words, we assume that

  • if \(i\neq -1\), then \(t[i]=0\) and \(N[i]=1\);
  • for all \(j\ (j >i)\), we have \(t[j]=N[j]\).
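To make the definition of \(i\) concrete, here is a small sketch that computes it for given \(t\) and \(N\). The function name first_diff is our own (hypothetical), and digits are numbered from 0 at the least significant end, as in the sample code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the original solution).
// Returns the 0-indexed position i of the most significant digit where t and n
// differ, or -1 if t == n. Requires 0 <= t <= n.
int first_diff(std::int64_t t, std::int64_t n) {
    assert(0 <= t && t <= n);
    for (int i = 62; i >= 0; i--) {  // scan from the most significant digit
        int tb = static_cast<int>(t >> i & 1);
        int nb = static_cast<int>(n >> i & 1);
        if (tb != nb) {
            // Because t < n, the first difference must be t[i] = 0 and n[i] = 1.
            assert(tb == 0 && nb == 1);
            return i;
        }
    }
    return -1;  // every digit agrees, so t == n
}
```

For instance, first_diff(10, 12) returns 2, since \(10=(1010)_2\) and \(12=(1100)_2\) first differ at digit 2.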

Recall that \(T\) is the set of integers obtained by replacing each ? in \(S\) with 0 or 1. Then \(t\in T\) if and only if all of the following conditions are satisfied:

  • if \(i\neq -1\), then (\(s[i]=\)? or \(s[i]=\)0) and \(N[i]=1\);
  • for all \(j\ (j >i)\), we have \(s[j]=\)? or \(s[j]=N[j]\);
  • for all \(j\ (j <i)\), we have \(s[j]=\)? or \(s[j]=t[j]\).
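The three conditions above are the definition of \(T\) (every non-? character of \(S\) must match the corresponding digit of \(t\)) split according to where \(t\) departs from \(N\). For testing, membership can also be checked directly. The sketch below uses a hypothetical name in_T; \(S\) is read most significant character first, as in the input, and digits of \(t\) beyond the length of \(S\) must be 0.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical helper (not part of the original solution).
// Returns true iff t can be obtained from pattern s by replacing each '?' with
// 0 or 1. s is given most significant character first, as in the input.
bool in_T(const std::string& s, std::int64_t t) {
    assert(t >= 0 && s.size() <= 60);
    const int len = static_cast<int>(s.size());
    for (int j = 0; j < 63; j++) {  // j-th digit, counted from the least significant
        int bit = static_cast<int>(t >> j & 1);
        if (j >= len) {
            if (bit == 1) return false;  // digits beyond |s| are fixed to 0
            continue;
        }
        char c = s[len - 1 - j];  // character of s at digit j
        if (c != '?' && c - '0' != bit) return false;
    }
    return true;
}
```

With \(S=\)?0?, the members are exactly 0, 1, 4 and 5.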

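Let \(i^*\) be the minimum \(i\) (with \(-1\) counted as the smallest) satisfying the first two conditions; it depends only on \(S\) and \(N\). The smaller \(i\) is, the larger \(t\) is: if \(t_1\) and \(t_2\) first differ from \(N\) at digits \(i_1 > i_2\), they agree above digit \(i_1\), while \(t_1[i_1]=0\) and \(t_2[i_1]=N[i_1]=1\), so \(t_1 < t_2\). Hence the answer is the maximum \(t\) satisfying \(i=i^*\) and the third condition, which is obtained by setting every ? below digit \(i^*\) to 1. Thus, the following algorithm solves this problem.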
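The choice of \(i^*\) can be sketched as follows. Only positions at or above the highest fixed mismatch between \(S\) and \(N\) (called lb, as in the sample code) satisfy the second condition, and among those, the smallest \(i\) that also satisfies the first condition gives the largest \(t\). The name find_i_star and the return value -2 for "no valid \(i\)" are our own conventions, not from the original.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical helper: returns i* for pattern s and bound n, i.e. the smallest
// i in {-1, 0, ..., 59} satisfying the first two conditions, or -2 if none does.
// Digits are 0-indexed from the least significant end; positions beyond |s|
// are treated as a fixed '0', like the padding in the sample code.
int find_i_star(const std::string& s, std::int64_t n) {
    assert(s.size() <= 60 && 0 <= n && n < (std::int64_t{1} << 60));
    const int len = static_cast<int>(s.size());
    auto pat = [&](int j) { return j < len ? s[len - 1 - j] : '0'; };
    auto nbit = [&](int j) { return static_cast<int>(n >> j & 1); };

    // lb = highest digit whose fixed character disagrees with n (-1 if none).
    // The second condition (no disagreement above i) holds exactly when i >= lb.
    int lb = -1;
    for (int j = 0; j < 60; j++) {
        if (pat(j) != '?' && pat(j) - '0' != nbit(j)) lb = j;
    }
    if (lb == -1) return -1;  // t = n itself is in T

    // First condition: s[i] is '?' or '0', and n[i] = 1. Take the smallest such i >= lb.
    for (int i = lb; i < 60; i++) {
        if (pat(i) != '1' && nbit(i) == 1) return i;
    }
    return -2;  // no t <= n lies in T
}
```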

  1. Find \(i^*\). If no \(i\) satisfies the first two conditions, the answer is -1.
  2. If \(i^* \neq -1\), set \(s[i^*]\) to 0.
  3. For \(j\ (j < i^*)\), replace \(s[j]\) with 1 if \(s[j]=\)?.
  4. For \(j\ (j > i^*)\), replace \(s[j]\) with \(N[j]\) if \(s[j]=\)?.
  5. Print the integer represented by the resulting binary string \(S\).
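As a small worked example (digits numbered from 0 at the least significant end, as in the sample code), take \(S=\)?0? and \(N=2\). The digit at \(i^*\) becomes 0, the ? below it become 1, and the ? above it copy \(N\):

\[
\begin{aligned}
(s[2], s[1], s[0]) &= (\texttt{?}, \texttt{0}, \texttt{?}), \quad (N[2], N[1], N[0]) = (0, 1, 0),\\
i = -1 &: \text{fails, since } s[1] = 0 \neq N[1],\\
i = 0 &: \text{fails, since } N[0] = 0,\\
i = 1 &: s[1] = 0,\ N[1] = 1,\ s[2] = \texttt{?} \;\Rightarrow\; i^* = 1,\\
t &= (N[2],\, 0,\, 1)_2 = (001)_2 = 1.
\end{aligned}
\]

Indeed \(T=\{0,1,4,5\}\) here, and its largest element not exceeding 2 is 1.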

Sample code (C++):

#include<bits/stdc++.h>

using namespace std;

using ll = long long;

int main() {
    string s;
    ll n;
    cin >> s >> n;
    
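    // Store the digits least significant first and pad to 60 digits (N < 2^60).
    reverse(s.begin(), s.end());
    while (s.size() < 60) s.push_back('0');
    // lb: the highest digit where a fixed character of S disagrees with N.
    int lb = -1;
    for (int i = 0; i < 60; i++) {
        if (s[i] != '?' and s[i] - '0' != (n >> i & 1)) lb = i;
    }
    // No disagreement at all: N itself is in T (i* = -1).
    if (lb == -1) {
        cout << n << endl;
        return 0;
    }
    // i* is the smallest i >= lb with s[i] in {'?', '0'} and N[i] = 1.
    for (int i = lb; i < 60; i++) {
        if (s[i] == '1' or !(n >> i & 1)) continue;
        s[i] = '0';
        // Below i*: fill every '?' with 1.
        for (int j = 0; j < i; j++) {
            if (s[j] == '?') s[j] = '1';
        }
        // Above i*: copy N (each s[j] here is '?' or already equals N[j]).
        for (int j = i + 1; j < 60; j++) {
            s[j] = '0' + (n >> j & 1);
        }
        // Convert the binary digits back to an integer.
        ll ans = 0;
        for (int j = 59; j >= 0; j--) {
            ans *= 2;
            ans += s[j] - '0';
        }
        cout << ans << endl;
        return 0;
    }
    // No valid i: every element of T exceeds N.
    cout << -1 << endl;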
}
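When testing the sample code on random inputs, a brute-force reference is handy. The sketch below (brute_force is our own name, not from the original) tries every way of filling the ?s and keeps the largest value not exceeding \(N\). It is exponential in the number of ?s, so it is only meant for short patterns.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Brute-force reference: enumerates every filling of the '?'s in s and returns
// the largest resulting value <= n, or -1 if there is none.
// Assumes at most 20 '?'s and |s| <= 60 so that values fit in int64.
std::int64_t brute_force(const std::string& s, std::int64_t n) {
    int q = 0;
    for (char c : s) if (c == '?') q++;
    assert(q <= 20 && s.size() <= 60);
    std::int64_t best = -1;
    for (std::int64_t mask = 0; mask < (std::int64_t{1} << q); mask++) {
        std::int64_t value = 0;
        int k = 0;  // index of the next '?' to fill
        for (char c : s) {  // s is most significant character first
            int bit = (c == '?') ? static_cast<int>(mask >> k++ & 1) : c - '0';
            value = value * 2 + bit;
        }
        if (value <= n && value > best) best = value;
    }
    return best;
}
```

Comparing its output with the sample code's on many small random \((S, N)\) pairs is a quick way to catch mistakes in the choice of \(i^*\).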
