Tree(POJ-1741)

题目描述:

Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.

输入

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.

输出

For each test case output the answer on a single line.

思路:

比较普遍的一道点分治题,考虑每一棵树,以重心为根,预处理出每个点的深度,再把每个点扔到一个数组中进行线性计算,算出满足条件的所有点对,方法可以是将其排序,用两个指针从两边往中间推着计算。

不过这时候会有小问题,会多算一种情况,就是他们的LCA不是重心的情况,这时候就需要采用容斥原理的思想,在每个重心的子树中计算一遍上述的操作(注意加上重心到根节点的距离),再在答案中对应地减去,便能得到最终答案!

代码

```c++

include

include

include

using namespace std;
bool mem1;
const int N=100005;
struct Graph{
int tot,to[N<<1],nxt[N<<1],len[N<<1],head[N];
void add(int x,int y,int z){tot++;to[tot]=y;nxt[tot]=head[x];len[tot]=z;head[x]=tot;}
void clear(){tot=0;memset(head,-1,sizeof(head));}
}G;
bool vis[N];
int ans,sz[N],mx[N],t_sz,center;
int arr[N],dep[N];
int n,k;
bool mem2;
void make_dep(int x,int f){
arr[++arr[0]]=dep[x];
for(int i=G.head[x];i!=-1;i=G.nxt[i]){
int v=G.to[i];
if(v==f||vis[v])continue;
dep[v]=dep[x]+G.len[i];
make_dep(v,x);
}
}
void get_center(int x,int f){
sz[x]=1,mx[x]=0;
for(int i=G.head[x];i!=-1;i=G.nxt[i]){
int v=G.to[i];
if(v==f||vis[v])continue;
get_center(v,x);
sz[x]+=sz[v];
mx[x]=max(mx[x],sz[v]);
}
mx[x]=max(mx[x],t_sz-sz[x]);
if(!center||mx[x]<mx[center])center=x;
}
int calc(int x,int dis){
dep[x]=dis,arr[0]=0;
make_dep(x,0);
sort(arr+1,arr+arr[0]+1);
int j=arr[0],ret=0;
for(int i=1;i<=arr[0];i++){
while(j>i&&arr[i]+arr[j]>k)j--;
ret+=max(0,j-i);
}
return ret;
}
void solve(int x){
vis[x]=1;
ans+=calc(x,0);
for(int i=G.head[x];i!=-1;i=G.nxt[i]){
int v=G.to[i];
if(vis[v])continue;
ans-=calc(v,G.len[i]);
center=0,t_sz=sz[v];
get_center(v,x);
solve(center);
}
}
int main(){
while(scanf("%d%d",&n,&k)==2){
if(!n&&!k)break;
G.clear();
memset(vis,0,sizeof vis);
for(int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
G.add(x,y,z),G.add(y,x,z);
}
center=0,t_sz=n,ans=0;
get_center(1,0);
solve(center);
printf("%d\n",ans);
}
return 0;
}

猜你喜欢

转载自www.cnblogs.com/Heinz/p/10458910.html