0 like 0 dislike
1,552 views
| 1,552 views

0 like 0 dislike
Bitwise Tree Splitting
An unweighted undirected tree can be represented with g_nodes nodes numbered from 0 to g_nodes - 1 and g_edges = g_nodes - 1 edges where the ith edge connects the nodes numbered g_from[i] to g_to[i]. Each node numbered x is associated with a value represented by value[x].

The bitwise AND of the tree is defined as the bitwise AND of all the values of the nodes of the tree.

Given g_nodes, g_from, g_to and value, find the number of ways to remove a single edge from the tree such that the bitwise ANDs of the two resulting trees are equal.

Sample Input 01
g_nodes = 4, g_edges = 3
g_from = [0, 1, 0], g_to = [1, 3, 2]
value[] size = 4
value = [3, 7, 3, 3]
Sample Output 01
3
Sample Input 02
g_nodes = 7, g_edges = 6
g_from = [0, 1, 2, 2, 4, 4], g_to = [1, 2, 3, 4, 5, 6]
value[] size = 7
value = [9, 9, 11, 9, 9, 15, 15]
Sample Output 02
4
My Solution
I didn't have time to submit this solution

#include<bits/stdc++.h>
using namespace std;

#define int long long

// Element-wise, in-place addition: a[k] += b[k] for every index k of a.
// b is indexed over the full length of a, so it must be at least as long as a.
void sum(std::vector<int>& a, const std::vector<int>& b) {
    const std::size_t len = a.size();
    for (std::size_t idx = 0; idx != len; ++idx) {
        a[idx] += b[idx];
    }
}

// Element-wise, in-place subtraction: a[k] -= b[k] for every index k of a.
// b is indexed over the full length of a, so it must be at least as long as a.
void sub(std::vector<int>& a, const std::vector<int>& b) {
    std::size_t pos = 0;
    for (int& slot : a) {
        slot -= b[pos];
        ++pos;
    }
}

// For each of the low 32 bit positions that is set in num, increments the
// matching counter a[position]. a must therefore hold at least 32 entries.
void add(int num, std::vector<int>& a) {
    for (int bit = 0; bit < 32; ++bit) {
        const long long mask = 1LL << bit;
        if ((num & mask) != 0) {
            a[bit] += 1;
        }
    }
}

// Decides whether two groups of nodes have the same bitwise AND, given per-bit
// population counts.
//
//   a      - a[k] = how many nodes of the first group have bit k set.
//   b      - b[k] = the same count for the second group (indexed over a's length).
//   below  - below[node] is the number of nodes in the first group.
//   total  - number of nodes in both groups together.
//   node   - index into below selecting the first group's size.
//
// A bit is present in a group's AND exactly when every node of that group has
// it, i.e. when the count equals the group size. Returns true when, for every
// bit position of a, both groups agree on whether the bit is present.
bool same(const std::vector<int>& a, const std::vector<int>& b, const std::vector<int>& below, int total, int node) {
    if (a.empty()) {
        return true;
    }
    const int first_size = below[node];
    const int second_size = total - first_size;
    for (std::size_t bit = 0; bit < a.size(); ++bit) {
        const bool in_first = (a[bit] == first_size);
        const bool in_second = (b[bit] == second_size);
        if (in_first != in_second) {
            return false;
        }
    }
    return true;
}

// Counts the edges of a tree whose removal leaves two components with equal
// bitwise AND of their node values.
//
// Parameters:
//   g_nodes      - number of nodes, labelled 0 .. g_nodes - 1.
//   g_from, g_to - endpoints of the i-th undirected edge.
//   val          - val[x] is the value of node x; only the low 32 bits are examined.
//
// Returns the number of qualifying edges, or 0 when there are fewer than two
// nodes (no edge can be removed).
//
// Preconditions (not checked): the edges form a connected tree over all
// g_nodes nodes, every endpoint lies in [0, g_nodes), and val holds at least
// g_nodes entries.
//
// Method: root the tree at node 0. For every node keep the size of its subtree
// and, per bit, how many subtree nodes have that bit set. A bit survives the
// AND of a group exactly when its count equals the group size. Cutting the
// edge above node v separates v's subtree from the rest, whose counts are the
// whole-tree counts minus the subtree counts.
int solve(int g_nodes, std::vector<int> g_from, std::vector<int> g_to, std::vector<int> val) {
    constexpr int kBits = 32;
    const int n = g_nodes;
    const int m = static_cast<int>(g_from.size());
    if (n < 2) {
        return 0;
    }

    // Undirected adjacency list.
    std::vector<std::vector<int>> adj(n);
    for (int i = 0; i < m; i++) {
        adj[g_from[i]].push_back(g_to[i]);
        adj[g_to[i]].push_back(g_from[i]);
    }

    // Breadth-first order from the root; every node appears after its parent.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    std::vector<bool> seen(n, false);
    order.push_back(0);
    seen[0] = true;
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        for (int v : adj[u]) {
            if (!seen[v]) {
                seen[v] = true;
                parent[v] = u;
                order.push_back(v);
            }
        }
    }

    // bits[x][b]: number of nodes in x's subtree with bit b set (seeded with x itself).
    // below[x]:   number of nodes in x's subtree (seeded with 1 for x itself).
    std::vector<std::vector<int>> bits(n, std::vector<int>(kBits, 0));
    std::vector<int> below(n, 1);
    for (int x = 0; x < n; x++) {
        for (int b = 0; b < kBits; b++) {
            if ((1LL << b) & val[x]) {
                bits[x][b] = 1;
            }
        }
    }

    // Walking the BFS order backwards visits children before parents, so each
    // node's totals are complete by the time they are folded into its parent.
    const int reached = static_cast<int>(order.size());
    for (int idx = reached - 1; idx >= 1; --idx) {
        const int v = order[idx];
        const int p = parent[v];
        below[p] += below[v];
        for (int b = 0; b < kBits; b++) {
            bits[p][b] += bits[v][b];
        }
    }

    // Each non-root node v stands for the edge (v, parent[v]).
    const std::vector<int>& whole = bits[0];
    int ans = 0;
    for (int idx = 1; idx < reached; ++idx) {
        const int v = order[idx];
        const int inside = below[v];
        const int outside = n - inside;
        bool equal = true;
        for (int b = 0; b < kBits && equal; b++) {
            const bool in_subtree = (bits[v][b] == inside);
            const bool in_rest = (whole[b] - bits[v][b] == outside);
            equal = (in_subtree == in_rest);
        }
        if (equal) {
            ans++;
        }
    }
    return ans;
}

// Driver: runs solve() on the second sample from the problem statement and
// prints the result on its own line. Declared `signed` because this file
// redefines `int` as `long long`.
signed main() {
    // Sample 01 (expected output: 3), kept here for manual testing:
    //   g_nodes = 4
    //   g_from  = {0, 1, 0}, g_to = {1, 3, 2}
    //   value   = {3, 7, 3, 3}

    // Sample 02 (expected output: 4).
    const int node_count = 7;
    const std::vector<int> edge_from{0, 1, 2, 2, 4, 4};
    const std::vector<int> edge_to{1, 2, 3, 4, 5, 6};
    const std::vector<int> node_values{9, 9, 11, 9, 9, 15, 15};

    const int answer = solve(node_count, edge_from, edge_to, node_values);
    std::cout << answer << '\n';

    return 0;
}
by Expert (113,040 points)