CodeForces - 1118F1 Tree Cutting (Easy Version): Tree DP

You are given an undirected tree of n vertices.

Some vertices are colored blue, some are colored red and some are uncolored. It is guaranteed that the tree contains at least one red vertex and at least one blue vertex.

You choose an edge and remove it from the tree. Tree falls apart into two connected components. Let's call an edge nice if neither of the resulting components contain vertices of both red and blue colors.

How many nice edges are there in the given tree?

Input

The first line contains a single integer n (2 ≤ n ≤ 3⋅10^5), the number of vertices in the tree.

The second line contains n integers a_1, a_2, …, a_n (0 ≤ a_i ≤ 2), the colors of the vertices. a_i = 1 means that vertex i is colored red, a_i = 2 means that vertex i is colored blue and a_i = 0 means that vertex i is uncolored.

The i-th of the next n−1 lines contains two integers v_i and u_i (1 ≤ v_i, u_i ≤ n, v_i ≠ u_i), the edges of the tree. It is guaranteed that the given edges form a tree. It is guaranteed that the tree contains at least one red vertex and at least one blue vertex.

Output

Print a single integer — the number of nice edges in the given tree.

Examples

Input

5
2 0 0 1 2
1 2
2 3
2 4
2 5

Output

1

Input

5
1 0 0 0 2
1 2
2 3
3 4
4 5

Output

4

Input

3
1 1 2
2 3
1 3

Output

0

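Problem summary: You are given a tree whose vertices are blue, red, or uncolored. Removing an edge splits the tree into two parts; the edge counts if neither part contains both red and blue vertices. Find the number of such edges.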

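Solution: tree DP. First compute, for every vertex, the number of red and blue vertices in its subtree (including the vertex itself), along with the total number of red and blue vertices in the whole tree. Then, for every non-root vertex, check whether its subtree contains all vertices of one color and none of the other. If so, the rest of the tree holds only the other color, and the edge from that vertex to its parent is nice.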

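#include<bits/stdc++.h>
using namespace std;
int n;
vector<int> v[300010];          // adjacency lists
int col[300010];                // col[i]: 0 = uncolored, 1 = red, 2 = blue
int red[300010],blue[300010];   // red/blue vertex counts in the subtree of each vertex
int ans;                        // number of nice edges
int re,bl;                      // total red / blue vertices in the whole tree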
// First pass: count red and blue vertices in every subtree,
// and accumulate the totals re and bl along the way.
void dfs(int u,int fa)
{
	if(col[u]==1) red[u]=1,re++;
	if(col[u]==2) blue[u]=1,bl++;
	for(int to:v[u])
	{
		if(to==fa) continue;
		dfs(to,u);
		red[u]+=red[to];
		blue[u]+=blue[to];
	}
}
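// Second pass: the edge (u,to) is nice when the subtree of to holds
// every vertex of one color and no vertex of the other color.
void dfs1(int u,int fa)
{
	for(int to:v[u])
	{
		if(to==fa) continue;
		if((red[to]==re&&blue[to]==0) || (blue[to]==bl&&red[to]==0))
			ans++;
		dfs1(to,u);
	}
}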
int main()
{
	int x,y;
	cin>>n;
	for(int i=1;i<=n;i++)
	{
		cin>>col[i];
	}
	for(int i=1;i<n;i++)
	{
		cin>>x>>y;
		v[x].push_back(y);
		v[y].push_back(x);
	}
	dfs(1,0);
	dfs1(1,0);
	cout<<ans<<endl;
	return 0;
}
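
The recursive solution above can recurse up to n levels deep when the tree is a path, and with n up to 3⋅10^5 that may exhaust the stack on systems with a small default stack size. Below is a minimal sketch of the same subtree-count check done without recursion, assuming a BFS order rooted at vertex 1. The function name countNiceEdges and its signature are hypothetical and do not come from the original post; colors[i] is taken to be the color of vertex i + 1.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper (not part of the original solution): counts nice
// edges iteratively. colors[i] is the color of vertex i + 1
// (0 = uncolored, 1 = red, 2 = blue); edges use 1-based vertex numbers.
int countNiceEdges(int n, const std::vector<int>& colors,
                   const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    std::vector<int> red(n + 1, 0), blue(n + 1, 0), parent(n + 1, 0);
    int totalRed = 0, totalBlue = 0;
    for (int v = 1; v <= n; ++v) {
        if (colors[v - 1] == 1) { red[v] = 1; ++totalRed; }
        if (colors[v - 1] == 2) { blue[v] = 1; ++totalBlue; }
    }

    // BFS from vertex 1: every vertex appears after its parent in `order`.
    std::vector<int> order;
    order.reserve(n);
    order.push_back(1);
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int to : adj[u]) {
            if (to == parent[u]) continue;  // in a tree, the only visited neighbor
            parent[to] = u;
            order.push_back(to);
        }
    }

    // Reverse BFS order: all children of v are handled before v, so
    // red[v] and blue[v] are complete subtree counts when v is reached.
    int nice = 0;
    for (int i = n - 1; i >= 1; --i) {
        int v = order[i];
        if ((red[v] == totalRed && blue[v] == 0) ||
            (blue[v] == totalBlue && red[v] == 0)) {
            ++nice;  // edge (v, parent[v]) separates the two colors
        }
        red[parent[v]] += red[v];
        blue[parent[v]] += blue[v];
    }
    return nice;
}
```

Calling countNiceEdges with the three sample inputs from the statement should reproduce the sample answers; the nice-edge condition is the same one used in dfs1, only the traversal differs.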


Reposted from blog.csdn.net/mmk27_word/article/details/87908927