Official

F - Double Sum 2 Editorial by en_translator


First, define \(d_k\) as the sum of \(A_i+A_j\) over all \((i,j)\) such that:

  • \(1\le i\le j\le N\); and
  • \(A_i+A_j\) is a multiple of \(2^k\).
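As a reference point for small inputs, the definition of \(d_k\) can be transcribed directly into an \(O(N^2)\) function. This is only a checking aid and not part of the intended solution. The name `d_naive` is hypothetical, and the list is 0-indexed.

```python
def d_naive(a: list[int], k: int) -> int:
    """Sum of A_i + A_j over all pairs i <= j where 2^k divides A_i + A_j (O(N^2))."""
    n = len(a)
    total = 0
    for i in range(n):
        for j in range(i, n):  # i <= j, so the pair i = j is included
            s = a[i] + a[j]
            if s % (1 << k) == 0:
                total += s
    return total


print(d_naive([4, 8], 3))  # 24: only 4+4=8 and 8+8=16 are multiples of 8
```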

Then, the sum of \(A_i+A_j\) over all \((i,j)\) such that

  • \(1\le i\le j\le N\); and
  • \(A_i+A_j\) is divisible by \(2\) exactly \(k\) times (that is, by \(2^k\) but not by \(2^{k+1}\))

equals \(d_k-d_{k+1}\). Dividing each such sum by \(2^k\) removes all of its factors of \(2\), so the answer is \(\displaystyle \sum_{k\geq 0}\frac{d_k-d_{k+1}}{2^k}\). Every \(A_i+A_j\) is positive and satisfies \(A_i+A_j\le 2\times 10^7<2^{25}\), so none of them is a multiple of \(2^{25}\). Hence \(d_k=0\) for \(k\ge 25\), and it suffices to consider \(0\le k\le 25\). From now on, we consider how to find \(d_k\) for a fixed \(k\).
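The telescoping step can be checked numerically for a single value \(x\). Write \([P]\) for \(1\) if \(P\) holds and \(0\) otherwise. Summing \(\bigl([2^k\mid x]-[2^{k+1}\mid x]\bigr)\cdot x/2^k\) over \(k\) leaves exactly one nonzero term: the one where \(k\) equals the number of factors of \(2\) in \(x\). That term is \(x\) with all factors of \(2\) removed. By linearity, the same holds for the whole sum over pairs. The following is a minimal sketch that assumes \(0<x<2^{25}\), as in this problem. The names `odd_part` and `telescoped` are hypothetical.

```python
def odd_part(x: int) -> int:
    """Divide x by 2 as long as it is even (assumes x > 0)."""
    while x % 2 == 0:
        x //= 2
    return x


def telescoped(x: int, K: int = 25) -> int:
    """Sum over 0 <= k < K of ([2^k | x] - [2^(k+1) | x]) * x / 2^k."""
    total = 0
    for k in range(K):
        in_k = x if x % (1 << k) == 0 else 0  # contribution of x to d_k
        in_k1 = x if x % (1 << (k + 1)) == 0 else 0  # contribution of x to d_{k+1}
        total += (in_k - in_k1) >> k  # nonzero only when 2 divides x exactly k times
    return total


# Each positive x < 2^25 survives at exactly one k, giving its odd part.
for x in range(1, 5000):
    assert telescoped(x) == odd_part(x)
```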

The condition that \(A_i+A_j\) is a multiple of \(2^k\) is equivalent to \(A_j\equiv -A_i\pmod{2^k}\). For a fixed \(j\), let \(C_j\) and \(S_j\) be the number and the sum, respectively, of the elements \(A_i\) such that:

  • \(1\le i\le j\); and
  • \(A_j\equiv -A_i\mod 2^k\),

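then \(\displaystyle d_k=\sum_{j=1}^N(C_jA_j+S_j)\). This holds because each pair \((i,j)\) with \(i\le j\) is counted exactly once, at index \(j\). There it contributes \(A_j\) (accounted for in \(C_jA_j\)) and \(A_i\) (accounted for in \(S_j\)). To find \(C_j\) and \(S_j\) efficiently, scan \(j=1,2,\dots,N\) in order while maintaining a dictionary (map). Its key is \(-A_i\bmod 2^k\), and its value is the count and sum of the elements \(A_i\) inserted so far with that key. For each \(j\), first insert \(A_j\) under the key \(-A_j\bmod 2^k\), so that the pair \(i=j\) is included. Then look up the key \(A_j\bmod 2^k\).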
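The dictionary technique for a single \(k\) might look like the sketch below. It keeps two plain dictionaries (count and sum per key) instead of the tuple-valued `defaultdict` used in the sample code. The name `d_with_map` is hypothetical, and the list is 0-indexed.

```python
from collections import defaultdict


def d_with_map(a: list[int], k: int) -> int:
    """d_k = sum of A_i + A_j over i <= j with 2^k | A_i + A_j, in one pass over j."""
    mod = 1 << k
    cnt = defaultdict(int)  # key (-A_i mod 2^k) -> number of A_i inserted so far
    tot = defaultdict(int)  # same key -> sum of those A_i
    d = 0
    for x in a:  # x plays the role of A_j
        key = -x % mod
        cnt[key] += 1  # insert A_j first so that the pair i = j is counted
        tot[key] += x
        q = x % mod  # A_i matches iff -A_i ≡ A_j (mod 2^k)
        d += cnt[q] * x + tot[q]  # C_j * A_j + S_j
    return d


print(d_with_map([4, 8], 3))  # 24
```

For \(A=(4,8)\) and \(k=3\), only \(4+4\) and \(8+8\) are multiples of \(8\), so the result is \(8+16=24\).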

Computing \(d_k\) this way for every \(0\le k\le 25\) and combining the results solves the problem. Each \(k\) takes \(O(N)\) dictionary operations, so the total complexity is \(O(N\log \max A)\).

Sample code (Python3)

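from collections import defaultdict
import sys
input = sys.stdin.readline
n = int(input())
a = list(map(int, input().split()))
K = 25  # A_i + A_j <= 2*10^7 < 2^25, so d_k = 0 for k >= K
cur = [0] * (K + 1)  # cur[k] = d_k; cur[K] stays 0
for k in range(K):
    kk = 1 << k
    # key (-A_i mod 2^k) -> (count, sum) of the A_i inserted so far
    s = defaultdict(lambda: (0, 0))
    for i in a:  # here i is the value A_j of the current element
        # insert A_j first so that the pair i = j is also counted
        c1, c2 = s[-i % kk]
        s[-i % kk] = c1 + 1, c2 + i
        # earlier A_i match iff -A_i ≡ A_j (mod 2^k); add C_j * A_j + S_j
        c1, c2 = s[i % kk]
        cur[k] += c2 + i * c1
ans = 0
for i in range(K):
    assert cur[i] % (1 << i) == 0  # every sum counted in d_i is a multiple of 2^i
    # sums divisible by 2 exactly i times contribute (A_i + A_j) / 2^i each
    ans += (cur[i] - cur[i + 1]) >> i
print(ans)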
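When modifying the sample code, it can help to compare it against a direct \(O(N^2)\) computation on small inputs. The sketch below assumes, as the decomposition above implies, that the quantity summed for each pair is \(A_i+A_j\) with all factors of \(2\) removed. The name `brute` is hypothetical. It uses the bit trick `s & -s`, which isolates the lowest set bit of a positive integer.

```python
def brute(a: list[int]) -> int:
    """O(N^2) reference: sum over i <= j of A_i + A_j with every factor 2 removed."""
    total = 0
    n = len(a)
    for i in range(n):
        for j in range(i, n):
            s = a[i] + a[j]
            total += s // (s & -s)  # s & -s is the largest power of 2 dividing s (s > 0)
    return total


print(brute([4, 8]))  # 5
```

For example, given the input `2` on the first line and `4 8` on the second, the sample code prints 5, which matches `brute([4, 8])`.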
