Define the function $\text{fun}(A, B) = A + B - 2\,(A \mathbin{\&} B)$, where $\&$ denotes the bitwise AND operation.

Given an array $a$ of size $N$, find the sum of $\text{fun}(a_i, \text{fun}(a_j, a_k))$ over all triplets of indices $i < j < k$.


in Service-based-companies
2,531 views

1 Answer

Best answer

Step 1: A well-known identity from bit manipulation states:


$A + B = (A \oplus B) + 2\,(A \mathbin{\&} B)$


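Why this works (the link in the original post is missing): XOR adds the bits without carrying. $A \mathbin{\&} B$ marks exactly the bit positions that produce a carry, and each carry is worth twice its position's place value, hence the $2\,(A \mathbin{\&} B)$ term.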


Rearranging the terms gives $A \oplus B = A + B - 2\,(A \mathbin{\&} B) = \text{fun}(A, B)$.


Hence $\text{fun}(a_i, \text{fun}(a_j, a_k)) = a_i \oplus a_j \oplus a_k$. The problem therefore reduces to summing the XOR of every triplet in the array: $\text{Answer} = \sum_{i<j<k} (a_i \oplus a_j \oplus a_k)$.


Step 2: To solve this, first consider a simpler problem: find the sum of the XOR of every pair in the array, $\sum_{i<j} (a_i \oplus a_j)$.


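This is a well-known technique (the link in the original post is missing): treat each bit position $b$ independently. If $x_b$ elements have bit $b$ set and $y_b = N - x_b$ do not, then exactly $x_b \, y_b$ pairs have bit $b$ set in their XOR. Therefore $\sum_{i<j} (a_i \oplus a_j) = \sum_b 2^b \, x_b \, y_b$.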


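Extending this to triplets: the XOR of three elements has bit $b$ set exactly when an odd number of them have that bit set. That happens in two ways: all three ($\binom{x_b}{3}$ ways) or exactly one ($x_b \binom{y_b}{2}$ ways). Therefore $\text{Answer} = \sum_b 2^b \left( \binom{x_b}{3} + x_b \binom{y_b}{2} \right)$, which the code below computes modulo $10^9 + 7$.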


Thus the problem is solved with a few simple observations from combinatorics and bit manipulation.

Time complexity: $O(N \cdot \log(\max_i a_i))$, i.e. one pass over the array per bit position.

C++ code:
#include <bits/stdc++.h>

using namespace std;

// Signed 64-bit integer alias used throughout this file.
using ll = long long;

// fac[i] is intended to hold i! modulo 1e9+7. The array has static storage,
// so it starts zero-initialized; main() fills a prefix of it before use.
ll fac[1000007];


// Computes (x^y) mod p by binary (square-and-multiply) exponentiation.
//
// Preconditions: y >= 0 and 1 <= p, with p small enough that the product of
// two residues fits in a signed 64-bit integer (true for p = 1e9+7).
// A negative base is normalized into [0, p) so the result is never negative.
ll power(ll x, ll y, ll p)
{
    ll res = 1 % p;  // "1 % p" makes the p == 1 case return 0, not 1
    x %= p;
    if (x < 0) {
        x += p;
    }

    while (y > 0) {
        if (y & 1) {
            res = res * x % p;
        }
        y >>= 1;
        x = x * x % p;
    }
    return res;
}



// Modular multiplicative inverse of n modulo p via Fermat's little theorem:
// n^(p-2) is congruent to n^(-1) (mod p).
//
// Preconditions: p is prime and n is not divisible by p. Otherwise the result
// is not an inverse.
ll modInverse(ll n, ll p)
{
    return power(n, p - 2, p);
}




// Binomial coefficient C(n, r) modulo a prime p, using Fermat inverses.
//
// Returns 0 when r < 0 or r > n, and 1 when r == 0.
// Assumes 0 <= n < p, so no factor below is divisible by p.
//
// Fast path: when the global table `fac` covers n and has been filled
// (as a prefix starting at fac[0], with the same modulus p), it computes
//   n! / (r! (n-r)!)
// with a single modular inverse. For n < p, n! mod p is never 0, so a zero
// entry means "not precomputed".
// Otherwise it falls back to the multiplicative formula
//   C(n, r) = n (n-1) ... (n-r+1) / r!
// which costs O(min(r, n-r)) multiplications and never indexes past the table.
ll nCrModPFermat(ll n, ll r, ll p)
{
    if (r < 0 || n < r) {
        return 0;
    }
    if (r == 0) {
        return 1;
    }

    const ll tableSize = static_cast<ll>(sizeof(fac) / sizeof(fac[0]));
    if (n < tableSize && fac[n] != 0) {
        return fac[n] * modInverse(fac[r] * fac[n - r] % p, p) % p;
    }

    r = std::min(r, n - r);
    ll numerator = 1;
    ll denominator = 1;
    for (ll i = 0; i < r; ++i) {
        numerator = numerator * ((n - i) % p) % p;
        denominator = denominator * ((i + 1) % p) % p;
    }
    return numerator * modInverse(denominator, p) % p;
}




// Returns the sum, over all index triplets i < j < k, of a[i] ^ a[j] ^ a[k],
// taken modulo 1e9+7.
//
// Each bit position b is handled independently. Let x be the number of
// elements with bit b set and y = n - x the number without it. A triplet's
// XOR has bit b set exactly when an odd number of its members have it:
// all three (C(x, 3) ways) or exactly one (x * C(y, 2) ways).
// Each such triplet contributes 2^b.
//
// Assumes every a[i] is non-negative; bits 0..62 are examined.
ll Sumofxor(ll a[], ll n)
{
    const ll mod = 1000000000 + 7;
    ll answer = 0;

    for (int bit = 0; bit < 63; ++bit) {
        // 1LL: a plain int shift breaks at bit 31 (yields INT_MIN).
        const ll mask = 1LL << bit;

        ll x = 0;
        for (ll i = 0; i < n; ++i) {
            if (a[i] & mask) {
                ++x;
            }
        }
        const ll y = n - x;

        const ll oddTriplets =
            (nCrModPFermat(x, 3, mod) + x % mod * nCrModPFermat(y, 2, mod)) % mod;
        answer = (answer + mask % mod * oddTriplets) % mod;
    }
    return answer;
}


int main() {



ios_base::sync_with_stdio(false) ;



    ll p = 1e9 + 7 ;

    fac[0] = 1;

    for (ll i = 1; i <= 100000+55; i++){

        fac[i] = (fac[i-1]*i) % p;


    ll n ;

    cin>>n;ll a[n];

    ll i = 0 ;







    return 0 ;

by Expert (108,110 points)
