#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <utility>
#include <vector>
using namespace std;

// Problem (as specified by total_brute_solve below): the input is a tree whose
// vertices carry a color 0 (uncolored), 1 or 2. Count the non-empty sets S of
// uncolored vertices that induce a connected subgraph and whose removal leaves
// at least one color-1 vertex disconnected from at least one color-2 vertex.
// The fast solver (Tree::solve) reports this count modulo `mod`.

const int mod = (int)1e9 + 7;

// Returns (x + y) mod `mod`; both inputs must already lie in [0, mod).
int add(int x, int y) {
    x += y;
    if (x >= mod) x -= mod;
    return x;
}

// Returns (x * y) mod `mod`, widening to 64 bits so the product cannot overflow.
int mul(int x, int y) {
    return static_cast<int>(1LL * x * y % mod);
}

// Returns a^b mod `mod` by binary exponentiation (not used by the solver).
int modpow(int a, int b) {
    long long res = 1;
    while (b > 0) {
        if (b & 1) res = res * (long long)a % mod;
        a = static_cast<int>(a * (long long)a % mod);
        b >>= 1;
    }
    return static_cast<int>(res);
}

// In-place transitive closure of a square 0/1 reachability matrix
// (Floyd-Warshall over booleans). O(n^3); only used by the brute-force checkers.
static void transitive_closure(vector<vector<int>> &m) {
    int n = m.size();
    for (int k = 0; k < n; k++)
        for (int i = 0; i < n; i++)
            if (m[i][k])
                for (int j = 0; j < n; j++)
                    if (m[k][j]) m[i][j] = 1;
}

// Exponential reference solution: enumerates every vertex subset of an n-vertex
// graph `g` (adjacency lists) with colors `color` and counts, without any
// modulus, the subsets that satisfy the problem statement above.
// Only practical for very small n; kept for stress-testing Tree::solve.
int total_brute_solve(int n, const vector<vector<int>> &g, const vector<int> &color) {
    int ans = 0;
    for (int mask = 1; mask < (1 << n); mask++) {
        // S = vertices selected by mask; every one of them must be uncolored.
        bool ok = true;
        vector<int> removed(n);
        for (int i = 0; i < n; i++) {
            if (mask & (1 << i)) {
                if (color[i]) {
                    ok = false;
                    break;
                }
                removed[i] = true;
            }
        }
        if (!ok) continue;

        // S must induce a connected subgraph.
        vector<vector<int>> C(n, vector<int>(n));
        for (int i = 0; i < n; i++)
            if (removed[i]) C[i][i] = true;
        for (int i = 0; i < n; i++)
            for (int to : g[i])
                if (removed[i] && removed[to]) C[i][to] = true;
        transitive_closure(C);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                if (removed[i] && removed[j] && !C[i][j]) ok = false;
        if (!ok) continue;

        // Reachability among the vertices that remain after deleting S.
        vector<vector<int>> conn(n, vector<int>(n));
        for (int i = 0; i < n; i++)
            if (!removed[i]) conn[i][i] = true;
        for (int from = 0; from < (int)g.size(); from++)
            if (!removed[from])
                for (int to : g[from])
                    if (!removed[to]) conn[from][to] = true;
        transitive_closure(conn);

        // S is counted when some color-1 vertex can no longer reach some color-2 vertex.
        bool good = false;
        for (int i = 0; i < n && !good; i++)
            for (int j = 0; j < n && !good; j++)
                if (color[i] == 1 && color[j] == 2 && !conn[i][j]) good = true;
        if (good) ans++;
    }
    return ans;
}

// A colored tree plus the machinery of the fast solver.
//
// Outline of solve():
//  1. find_safe_to_delete_vertices(): root the tree at an uncolored vertex and
//     collect every vertex lying below a colored vertex c whose subtree holds no
//     colored vertex other than c. Those vertices are all uncolored, and any
//     connected uncolored set containing one stays below c, so it cannot cut
//     colored vertices apart; they are dropped.
//  2. find_connected_components(): split the remaining tree into maximal
//     uncolored components, each together with its adjacent colored vertices.
//  3. solveSingle(): inside one component, repeatedly prune uncolored leaves.
//     Surviving uncolored vertices ("necessary") lie on a path between two
//     colored vertices, so removing any of them separates colored vertices;
//     because both colors are present (asserted in main) some 1/2 pair is then
//     separated. The component contributes the number of connected uncolored
//     sets containing at least one necessary vertex.
//
// All traversals are iterative so that deep (path-like) trees with ~1e5
// vertices cannot overflow the call stack.
struct Tree {
    vector<vector<int>> g;          // adjacency lists
    vector<vector<vector<int>>> dp; // dp[v][take][hasNecessary]; see DP()
    vector<int> color;              // 0 = uncolored, otherwise the color value
    vector<int> colored_nodes;      // colored vertices in v's subtree (cnt_color)

    Tree() = default;
    explicit Tree(int n) : g(n), color(n), colored_nodes(n) {}

    void addEdge(int from, int to) {
        g[from].push_back(to);
        g[to].push_back(from);
    }

    void addColor(int u, int col) { color[u] = col; }

    // Iterative DFS from `root`, treating `rootPar` (or -1) as its parent.
    // Fills parent[] for every reached vertex and returns the reached vertices
    // in preorder, so each vertex precedes all of its descendants.
    vector<int> preorder(int root, int rootPar, vector<int> &parent) const {
        parent.assign(g.size(), -1);
        parent[root] = rootPar;
        vector<int> order;
        vector<int> stk = {root};
        while (!stk.empty()) {
            int v = stk.back();
            stk.pop_back();
            order.push_back(v);
            for (int to : g[v]) {
                if (to == parent[v]) continue;
                parent[to] = v;
                stk.push_back(to);
            }
        }
        return order;
    }

    // Fills dp[v][1][nt] for every v in the subtree of `from` (parent `par`):
    // the number of connected vertex sets whose topmost vertex is v and that
    // contain a vertex with necessary[] != 0 iff nt == 1. Every connected set
    // is counted exactly once, at its topmost vertex. dp[v][0][*] is unused.
    // Requires dp to be sized [g.size()][2][2] beforehand.
    void DP(int from, int par, const vector<int> &necessary) {
        vector<int> parent;
        vector<int> order = preorder(from, par, parent);
        // Reverse preorder processes every child before its parent.
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            int v = *it;
            // f[j]: ways to attach child-rooted sets chosen so far to v, where
            // j = 1 iff the set built so far (v included) has a necessary vertex.
            int f[2] = {0, 0};
            f[necessary[v] ? 1 : 0] = 1;
            for (int to : g[v]) {
                if (to == parent[v]) continue;
                int nf[2] = {f[0], f[1]};  // option: leave child `to` out
                for (int j = 0; j < 2; j++) {
                    for (int cand = 0; cand < 2; cand++) {
                        int nj = cand == 1 ? 1 : j;
                        nf[nj] = add(nf[nj], mul(f[j], dp[to][1][cand]));
                    }
                }
                f[0] = nf[0];
                f[1] = nf[1];
            }
            dp[v][1][0] = f[0];
            dp[v][1][1] = f[1];
        }
    }

    // Debug helper: prints every directed adjacency entry to stderr.
    void printTree() const {
        for (int i = 0; i < (int)g.size(); i++)
            for (int to : g[i]) cerr << "edge " << i << " " << to << '\n';
    }

    // Exponential reference for numConnectedSets: counts (without modulus) the
    // non-empty sets of uncolored vertices that induce a connected subgraph and,
    // if any vertex is marked in `necessary`, contain at least one marked vertex.
    int brute_numConnectedSets(const vector<int> &necessary) const {
        int n = g.size();
        bool anyNecessary = false;
        for (int i = 0; i < n; i++)
            if (necessary[i]) anyNecessary = true;

        int ans = 0;
        for (int mask = 1; mask < (1 << n); mask++) {
            bool ok = true;
            vector<int> removed(n);
            for (int i = 0; i < n; i++) {
                if (mask & (1 << i)) {
                    if (color[i]) {
                        ok = false;
                        break;
                    }
                    removed[i] = true;
                }
            }
            bool necessaryRemoved = false;
            for (int i = 0; i < n; i++)
                if (removed[i] && necessary[i]) necessaryRemoved = true;
            if (anyNecessary && !necessaryRemoved) ok = false;
            if (!ok) continue;

            // The selected set must be connected on its own.
            vector<vector<int>> conn(n, vector<int>(n));
            for (int i = 0; i < n; i++)
                if (removed[i]) conn[i][i] = true;
            for (int from = 0; from < n; from++)
                if (removed[from])
                    for (int to : g[from])
                        if (removed[to]) conn[from][to] = true;
            transitive_closure(conn);

            bool good = true;
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    if (removed[i] && removed[j] && !conn[i][j]) good = false;
            if (good) ans++;
        }
        return ans;
    }

    // Exponential reference for solve() on this tree (no modulus).
    int brute_solve() const { return total_brute_solve(g.size(), g, color); }

    // Number (mod `mod`) of connected vertex sets containing at least one vertex
    // marked in `necessary`. The tree must be connected; empty tree -> 0.
    int numConnectedSets(const vector<int> &necessary) {
        if (g.empty()) return 0;
        dp.assign(g.size(), vector<vector<int>>(2, vector<int>(2, 0)));
        DP(0, -1, necessary);
        int ans = 0;
        for (int i = 0; i < (int)g.size(); i++) ans = add(ans, dp[i][1][1]);
        return ans;
    }

    // Adds to colored_nodes[v] the number of colored vertices in the subtree of
    // v (rooted at `from`, parent `par`), for every v in that subtree.
    void cnt_color(int from, int par) {
        vector<int> parent;
        vector<int> order = preorder(from, par, parent);
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            int v = *it;
            if (color[v]) colored_nodes[v]++;
            if (v != from) colored_nodes[parent[v]] += colored_nodes[v];
        }
    }

    // Returns the first uncolored vertex; the caller guarantees one exists.
    int calc_root() const {
        for (int i = 0; i < (int)g.size(); i++)
            if (color[i] == 0) return i;
        assert(false && "calc_root: every vertex is colored");
        return -1;
    }

    // Recomputes colored_nodes for the tree rooted at `root`.
    void calc_cnt_color(int root) {
        colored_nodes.assign(g.size(), 0);  // cnt_color accumulates; start clean
        cnt_color(root, -1);
    }

    // Stamps `comp` on the uncolored component containing `from` and on its
    // not-yet-stamped colored neighbours. Colored vertices are stamped but never
    // expanded, so a colored vertex keeps the first component that reached it.
    void dfs_conn(int from, vector<int> &visited, int comp) const {
        visited[from] = comp;
        vector<int> stk = {from};
        while (!stk.empty()) {
            int v = stk.back();
            stk.pop_back();
            if (color[v]) continue;
            for (int to : g[v]) {
                if (visited[to]) continue;
                visited[to] = comp;
                stk.push_back(to);
            }
        }
    }

    // Inserts into bad_nodes every vertex of the subtree of `from` (parent `par`)
    // that lies strictly below a colored vertex c with colored_nodes[c] == 1, or
    // every such vertex at all when `can_take` is already set on entry.
    void calc_bad_nodes(int from, int par, int can_take, set<int> &bad_nodes) {
        vector<int> parent;
        vector<int> order = preorder(from, par, parent);
        vector<char> passDown(g.size(), 0);
        for (int v : order) {  // preorder: a parent's flag is ready before its children
            bool inherited = (v == from) ? (can_take != 0) : (passDown[parent[v]] != 0);
            if (inherited) bad_nodes.insert(v);
            passDown[v] = inherited || (color[v] && colored_nodes[v] == 1);
        }
    }

    // Vertices that can never belong to a counted set (see step 1 above).
    set<int> find_safe_to_delete_vertices() {
        int root = calc_root();
        calc_cnt_color(root);
        set<int> bad_nodes;
        calc_bad_nodes(root, -1, false, bad_nodes);
        return bad_nodes;
    }

    // Returns the subgraph induced by the vertices not in delSet, relabelled
    // 0..k-1 in increasing original order, with colors carried over.
    Tree get_updated_tree(const set<int> &delSet) const {
        int n = g.size();
        vector<int> newId(n, -1);
        int idx = 0;
        for (int i = 0; i < n; i++)
            if (!delSet.count(i)) newId[i] = idx++;
        Tree tree(idx);
        for (int i = 0; i < n; i++) {
            if (newId[i] < 0) continue;
            for (int to : g[i])
                if (i < to && newId[to] >= 0) tree.addEdge(newId[i], newId[to]);
        }
        for (int i = 0; i < n; i++)
            if (newId[i] >= 0) tree.addColor(newId[i], color[i]);
        return tree;
    }

    // One Tree per maximal uncolored component, containing that component plus
    // all colored vertices adjacent to it (relabelled, colors carried over).
    vector<Tree> find_connected_components() const {
        int n = g.size();
        vector<int> visited(n);
        int comp = 1;
        for (int i = 0; i < n; i++) {
            if (visited[i] || color[i]) continue;
            dfs_conn(i, visited, comp);
            comp++;
        }
        vector<vector<int>> vertices(comp);
        for (int i = 0; i < n; i++)
            if (visited[i]) vertices[visited[i]].push_back(i);

        vector<Tree> trees;
        for (int c = 1; c < comp; c++) {
            const vector<int> &old_vertices = vertices[c];  // ascending order
            set<int> new_vertices(old_vertices.begin(), old_vertices.end());
            for (int from : old_vertices)
                for (int to : g[from])
                    if (color[to]) new_vertices.insert(to);

            map<int, int> mp;
            int idx = 0;
            for (int v : new_vertices) mp[v] = idx++;

            Tree tree((int)new_vertices.size());
            set<pair<int, int>> edgeSet;  // each undirected edge is added once
            for (int from : old_vertices) {
                for (int to : g[from]) {
                    if (!new_vertices.count(to)) continue;
                    int a = from, b = to;
                    if (a > b) swap(a, b);
                    if (!edgeSet.insert({a, b}).second) continue;
                    tree.addEdge(mp[a], mp[b]);
                }
            }
            for (int v : new_vertices) tree.addColor(mp[v], color[v]);
            trees.push_back(tree);
        }
        return trees;
    }

    // Answer contribution of one component tree (see step 3 above).
    int solveSingle() const {
        int n = g.size();
        vector<int> deg(n);
        for (int v = 0; v < n; v++) deg[v] = g[v].size();

        // Repeatedly strip uncolored vertices of degree <= 1.
        queue<int> Q;
        vector<char> is_deleted(n, 0);
        for (int i = 0; i < n; i++)
            if (deg[i] <= 1 && !color[i]) Q.push(i);
        while (!Q.empty()) {
            int from = Q.front();
            Q.pop();
            if (is_deleted[from]) continue;  // never decrement neighbours twice
            is_deleted[from] = true;
            for (int to : g[from]) {
                deg[to]--;
                if (!is_deleted[to] && !color[to] && deg[to] <= 1) Q.push(to);
            }
        }

        // Drop colored vertices; surviving uncolored vertices are "necessary".
        set<int> delSet;
        vector<int> newId(n, -1);
        int idx = 0;
        for (int i = 0; i < n; i++) {
            if (color[i]) delSet.insert(i);
            else newId[i] = idx++;
        }
        Tree tree = get_updated_tree(delSet);  // same relabelling as newId
        vector<int> necessary(tree.g.size());
        for (int i = 0; i < n; i++)
            if (!color[i] && !is_deleted[i]) necessary[newId[i]] = 1;
        return tree.numConnectedSets(necessary);
    }

    // Fast solution: the count described at the top of the file, mod `mod`.
    int solve() {
        int n = g.size();
        int colored = 0;
        for (int i = 0; i < n; i++)
            if (color[i]) colored++;
        if (colored == n) return 0;  // no uncolored vertex can be removed

        set<int> delSet = find_safe_to_delete_vertices();
        Tree tree = get_updated_tree(delSet);
        vector<Tree> trees = tree.find_connected_components();
        int ans = 0;
        for (const Tree &t : trees) ans = add(ans, t.solveSingle());
        return ans;
    }
};

// Reads one int from stdin, aborting with a message on malformed/missing input.
static int readInt() {
    int x;
    if (scanf("%d", &x) != 1) {
        fprintf(stderr, "unexpected end of input\n");
        exit(EXIT_FAILURE);
    }
    return x;
}

// Input: T test cases; each has n, then n-1 edges (1-based), then n colors.
int main() {
    int T = readInt();
    assert(T >= 1 && T <= (int)2e5);
    while (T--) {
        int n = readInt();
        assert(n >= 1 && n <= (int)1e5);
        Tree tree(n);
        for (int i = 0; i + 1 < n; i++) {
            int from = readInt();
            int to = readInt();
            assert(from >= 1 && from <= n);
            assert(to >= 1 && to <= n);
            assert(from != to);
            tree.addEdge(from - 1, to - 1);
        }
        bool has_black = false, has_red = false;
        for (int i = 0; i < n; i++) {
            int which = readInt();
            if (which != 0) {
                if (which == 1) has_black = true;
                else has_red = true;
            }
            tree.addColor(i, which);
        }
        assert(has_black && has_red);
        // For small n, tree.solve() can be stress-tested against
        // tree.brute_solve(), and numConnectedSets() against
        // brute_numConnectedSets() for arbitrary `necessary` masks.
        printf("%d\n", tree.solve());
    }
    return 0;
}