树的分治

我们先随意指定一个根rt,将这棵树转化成有根树

不难发现树上的路径分为两类, 经过根节点rt的路径和包含于rt的某棵子树里(不经过rt)的

对于前者, 我们用dis[u]dis[u]表示结点uu到根节点rtrt的路径长度, 则u到v的路径长即为dis[u]+dis[v]dis[u]+dis[v]

对于后者, 既然uu到vv的路径包含在rtrt的某个子树内, 那么我们就找到这棵子树的根,再对他求一次第一类路径

这样分治的思想就很明显了

就是把原来的树分成很多小的子树,并对每个子树分别求解第一类路径

点分治过程中,每一层的所有递归过程合计对每个点处理一次, 假设共递归T层,则总时间复杂度为O(T*N)O(T∗N)

然而,如果树退化成一条链, 那么递归层数就是T=nT=n,总时间复杂度为O(N^2)O(N2)

这样显然不能承受,所以我们要让树的层数经量少 这里就要找树的重心

/*
poj 1741
在一棵树上有多少对点,俩点间的最短路径小于等于m
用树的分治做
 
*/

#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
using namespace std;
const int N=10010;
int head[N<<2];
struct node
{
	int to,di,next;
}edge[N<<2];
int cc;
void addedge(int from,int to,int di)
{
	edge[cc].to=to;
	edge[cc].di=di;
	edge[cc].next=head[from];
	head[from]=cc++;
}
int n,m;
bool vis[N];
int root,f[N],son[N],sz,ans,dis[N],d[N],cnt;
void getroot(int u,int fa)//求树的重心 
{
	f[u]=0;
	son[u]=1;
	for(int i=head[u];i!=-1;i=edge[i].next){
		int to=edge[i].to;
		if(vis[to]||to==fa) continue;
		getroot(to,u);
		son[u]+=son[to];
		f[u]=max(f[u],son[to]);
	}
	f[u]=max(f[u],sz-son[u]);
	if(f[u]<f[root]){
		root=u;
	}
	
}
void getdis(int u,int fa)//求节点到根节点的距离 
{
	son[u]=1;
	d[cnt++]=dis[u];//重新定义一个序 是dfs序的思想 
	for(int i=head[u];i!=-1;i=edge[i].next){
		int to=edge[i].to;
		if(vis[to]||to==fa) continue;
		dis[to]=dis[u]+edge[i].di;
		getdis(to,u);
		son[u]+=son[to];
	}
}
int cont(int u,int mit)
{
	int res=0,l,r;
	dis[u]=mit;
	cnt=0;
	getdis(u,0);
	sort(d,d+cnt);//将点到根节点的距离排序 
	for(l=0,r=cnt-1;l<r;){
		if(d[l]+d[r]<=m) res+=(r-l++);.//代表的节点可以与r-l个节点成对 
		else r--;
	}
	return res;
	
}
void solve(int u)
{
	vis[u]=true;
	ans+=cont(u,0);//路径经过该根节点的点对数 
	for(int i=head[u];i!=-1;i=edge[i].next){
		int to=edge[i].to;
		if(vis[to]) continue;
		ans-=cont(to,edge[i].di);//to这个根节点算的对数 有的在u根结点已经算过了 
		//所以要减掉这个节点的值 
		root=0;
		f[root]=sz=son[to];//这里更新一下 to的这棵树节点个数 
		getroot(to,0);//求to子树的根节点 
		solve(root);//求to子树的点对数 
	}
}
int main()
{
	while(scanf("%d %d",&n,&m)==2){
		if(n==0&&m==0) break;
		memset(head,-1,sizeof(head));
		cc=0;
		for(int i=1;i<n;i++){
			int x,y,z;
			scanf("%d %d %d",&x,&y,&z);
			addedge(x,y,z);
			addedge(y,x,z);
		}
		memset(vis,false,sizeof(vis));
		root=0;
		f[root]=sz=n;
		ans=0;
		getroot(1,0);
		solve(root);
		printf("%d\n",ans);
	}
	return 0;
} 

猜你喜欢

转载自blog.csdn.net/CC_1012/article/details/91980177