Bitwise Tree Splitting
An unweighted, undirected tree is represented by g_nodes nodes numbered from 0 to g_nodes - 1 and g_edges = g_nodes - 1 edges, where the i-th edge connects the nodes numbered g_from[i] and g_to[i]. Each node numbered x is associated with a value, value[x].
The bitwise AND of the tree is defined as the bitwise AND of all the values of the nodes of the tree.
Given g_nodes, g_from, g_to and value, find the number of ways to remove a single edge from the tree such that the bitwise ANDs of the two resulting trees are equal.
Sample Input 01
g_nodes = 4, g_edges = 3
g_from = [0, 1, 0], g_to = [1, 3, 2]
value[] size = 4
value = [3, 7, 3, 3]
Sample Output 01
3
Sample Input 02
g_nodes = 7, g_edges = 6
g_from = [0, 1, 2, 2, 4, 4], g_to = [1, 2, 3, 4, 5, 6]
value[] size = 7
value = [9, 9, 11, 9, 9, 15, 15]
Sample Output 02
4
My Solution
I didn't have time to submit this solution
#include<bits/stdc++.h>
using namespace std;
#define int long long
// Element-wise accumulation: every a[i] becomes a[i] + b[i].
// `b` is read at each index of `a`, so it must be at least as long as `a`.
void sum(vector<int>& a, vector<int> const& b) {
    std::transform(a.begin(), a.end(), b.begin(), a.begin(),
                   [](int lhs, int rhs) { return lhs + rhs; });
}
// Element-wise subtraction: every a[i] becomes a[i] - b[i].
// `b` is read at each index of `a`, so it must be at least as long as `a`.
void sub(vector<int>& a, vector<int> const& b) {
    std::transform(a.begin(), a.end(), b.begin(), a.begin(),
                   [](int lhs, int rhs) { return lhs - rhs; });
}
// Records the set bits of `num` in the per-bit counters `a`: a[i] is
// incremented once for every bit position i at which `num` has a 1.
//
// The number of bit positions examined is taken from `a` itself (capped at
// 64), rather than a fixed 32, so a counter vector with fewer slots is never
// written out of bounds. For the 32-slot vectors used by solve() this is
// exactly the low 32 bits of `num`.
void add(int num, vector<int>& a) {
    // Work on the unsigned bit pattern so that shifting is well defined for
    // negative inputs as well.
    const unsigned long long pattern = static_cast<unsigned long long>(num);
    const size_t limit = min<size_t>(a.size(), 64);
    for (size_t i = 0; i < limit; i++) {
        if ((pattern >> i) & 1ULL) a[i]++;
    }
}
bool same(vector<int> const& a, vector<int> const& b, vector<int> const& below, int total, int node) {
bool flag = true;
for(int i = 0; i < (int)a.size(); i++) {
bool sa = false, sb = false;
sa = a[i] == below[node];
sb = b[i] == (total - below[node]);
if(sa != sb) flag = false;
}
return flag;
}
// Counts the edges whose removal leaves two components with equal bitwise AND.
//
// Parameters:
//   g_nodes      - number of nodes, labelled 0 .. g_nodes - 1
//   g_from, g_to - endpoints of the i-th edge (g_from[i] -- g_to[i])
//   val          - val[x] is the value stored at node x
// Returns the number of qualifying edges. Returns 0 when there are fewer than
// two nodes or when no degree-1 node exists to serve as the root.
//
// Method: one leaf is reserved as the root and the tree is peeled from the
// other leaves towards it. For every node u this accumulates, over the part of
// the tree hanging below u (u included), the node count below[u] and, per bit,
// how many of those nodes have the bit set (bits[u][b], 32 counters per node).
// Each non-root node u stands for the edge to its parent: one side of the cut
// is u's accumulated data, the other is the root's totals minus u's.
int solve(int g_nodes, vector<int> g_from, vector<int> g_to, vector<int> val) {
    const int n = g_nodes;
    // With fewer than two nodes there is no edge to cut. Returning here also
    // prevents the out-of-range read of bits[-1] that would otherwise happen
    // because no leaf exists.
    if (n < 2) return 0;

    // Only use edges for which both endpoints were supplied.
    const int m = static_cast<int>(min(g_from.size(), g_to.size()));
    vector<vector<int>> adj(n);
    vector<int> deg(n, 0);
    for (int i = 0; i < m; i++) {
        adj[g_from[i]].push_back(g_to[i]);
        adj[g_to[i]].push_back(g_from[i]);
        deg[g_from[i]]++, deg[g_to[i]]++;
    }

    // 32 per-bit counters for every node.
    vector<vector<int>> bits(n, vector<int>(32, 0));
    queue<int> q;
    int lst = -1;  // the first leaf found; kept out of the queue so it ends up as the root
    for (int i = 0; i < n; i++) {
        if (deg[i] == 1) {
            if (lst == -1) lst = i;
            else q.push(i);
        }
    }
    // No leaf at all means the input is not a tree; there is no root to read
    // the totals from.
    if (lst == -1) return 0;

    vector<bool> vis(n, false);
    vector<int> below(n, 0);
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        // u's own value and itself complete the data gathered from its children.
        add(val[u], bits[u]);
        below[u]++;
        vis[u] = true;
        for (int v : adj[u]) {
            if (vis[v]) continue;  // already-processed neighbours are u's children
            // v is u's parent: fold u's finished totals into it.
            sum(bits[v], bits[u]);
            below[v] += below[u];
            deg[v]--;
            // deg[v] == 1: only v's own parent is still pending, so v is ready.
            // deg[v] == 0: v is the reserved root and all its neighbours are done.
            if (deg[v] == 1 || deg[v] == 0) {
                q.push(v);
            }
        }
    }

    const vector<int>& ttl = bits[lst];  // per-bit counts over the whole tree
    vector<int> rest;                    // reused buffer for the far side of each cut
    int ans = 0;
    for (int i = 0; i < n; i++) {
        if (i == lst) continue;  // the root has no parent edge
        rest = ttl;
        sub(rest, bits[i]);
        if (same(bits[i], rest, below, n, i)) ans++;
    }
    return ans;
}
// Driver: runs solve() on a hard-coded sample and prints the result.
// `signed` is used instead of `int` because the file redefines `int`.
signed main() {
    // Sample 01 (expected output: 3)
    // int node_count = 4;
    // vector<int> edge_from{0, 1, 0}, edge_to{1, 3, 2};
    // vector<int> node_values{3, 7, 3, 3};

    // Sample 02 (expected output: 4)
    int node_count{7};
    vector<int> edge_from{0, 1, 2, 2, 4, 4};
    vector<int> edge_to{1, 2, 3, 4, 5, 6};
    vector<int> node_values{9, 9, 11, 9, 9, 15, 15};

    const auto answer = solve(node_count, edge_from, edge_to, node_values);
    cout << answer << '\n';
    return 0;
}