Description
机房断网了!xj轻而易举地撬开了中心机房的锁,拉着zwl走了进去。他们发现中心主机爆炸了。
中心主机爆炸后分裂成了 n 块碎片,但碎片仍然互相连接,形成一个树的结构。每个碎片有一个状态值0或1 。zwl找了一下规律,发现只有所有碎片的状态值相同的时候,主机才能够修复。
xj碰了碰其中一个碎片 x ,发现对于满足 x 到 v 的路径上所有碎片的状态值与 x 的状态值相同 的那些碎片 v 状态值都取反(0变1,1变0)了!
现在他们要尝试修复这个网络,最少需要多少次触碰呢?
Input
碎片从 1 到 n 编号。
第一行一个整数 n ,第二行 n 个数 0 或 1, 第 i 个数表示 i 号碎片的状态值。
接下来 n−1 行,每行两个数 x,y 表示 x 与 y 碎片中有连接。
Output
一行一个数,表示最少需要的碰撞次数。
Sample Input
11
0 0 0 1 1 0 1 0 0 1 1
1 2
1 3
2 4
2 5
5 6
5 7
3 8
3 9
3 10
9 11
Sample Output
2
HINT
样例解释:首先触碰三号碎片,再触碰六号碎片,这样所有碎片的状态值都会变为1 ,共触碰两次。
数据范围如下:
对于 20% 的数据,n≤10
对于 100% 的数据,n≤5×105
思路
题意为一次触碰表示把这个点的同色连通块取反,问需要多少次触碰使得整个树是同颜色的。
暴力枚举所有触碰方案即可得到20分。
对于更高的目标,首先可以发现的是一个连通块可以缩成一个点。
接下来我们要说明,这题的答案是 ⌊d+12⌋ ,其中 d 是缩点后树的直径。
对树缩点后,整棵树变成了一个黑白(姑且用黑白代表01)相间的分层树,那么我们任意触碰一个点,会使得这个点和上下的颜色变成相同,再次缩点,之后树的直径最多比之前减少2 。因此一次操作最多使树的直径减2,那么 ⌊d+12⌋ 就是答案的下界,接下来说明这个下界一定是可以取到的。
找到树上的一个点,使得这个点到其他所有点的距离都不超过 ⌊d+12⌋ 。这个点是一定可以找到的,否则就会与直径长度为 d 矛盾,⌊d+12⌋+⌊d+12⌋+1>d 。
我们不断地触碰这个点并缩点,那么树的直径每次都会减小2,直到只剩一个点或两个点,最后最多再触碰一次,所以答案就是 ⌊d+12⌋ 了。
复杂度为 O(n) 。
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
struct edge
{
int to,nxt;
} e[1000010],eg[1000010];
int n,a[500010],head[500010]= {0};
int tot=0,col[500010]= {0};
int cnt=0,dep[500010],ghead[500010]= {0};
void dfs(int u,int fa)
{
col[u]=tot;
for(int i=head[u]; i; i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
continue;
if(a[v]==a[u])
dfs(v,u);
}
return;
}
void dfs(int u,int fa,int d)
{
dep[u]=d;
for(int i=ghead[u]; i; i=eg[i].nxt)
{
int v=eg[i].to;
if(dep[v]) continue;
dfs(v,u,d+1);
}
return;
}
int main()
{
scanf("%d",&n);
for(int i=1; i<=n; i++)
scanf("%d",&a[i]);
for(int i=1; i<n; i++)
{
int u,v;
scanf("%d%d",&u,&v);
e[i]=(edge)
{
v,head[u]
};
head[u]=i;
e[n+i]=(edge)
{
u,head[v]
};
head[v]=n+i;
}
for(int i=1; i<=n; i++)
if(!col[i])
tot++,dfs(i,0);
for(int u=1; u<=n; u++)
for(int i=head[u]; i; i=e[i].nxt)
{
int v=e[i].to;
if(col[u]==col[v])
continue;
eg[++cnt]=(edge)
{
col[v],ghead[col[u]]
};
ghead[col[u]]=cnt;
eg[++cnt]=(edge)
{
col[u],ghead[col[v]]
};
ghead[col[v]]=cnt;
}
memset(dep,0,sizeof(dep));
dfs(1,0,1);
int id=1;
for(int i=2; i<=tot; i++)
if(dep[i]>dep[id])
id=i;
memset(dep,0,sizeof(dep));
dfs(id,0,1);
int maxn=0;
for(int i=1; i<=tot; i++)
maxn=max(maxn,dep[i]);
printf("%d\n",maxn/2);
return 0;
}