树上点分治详解【入门向】

       本蒟蒻想要给大家写一篇尽可能详细的树上点分治的文章,以便刚入门的各位能够理解树上点分治,就不用像我一样在网上看了十几篇大佬的文章后还很蒙逼了(我太菜了QAQ),那么,我们就进入正题吧!

       首先安利一波Guess_Ha大佬的树上点分治

       树上点分治主要解决有关树上路径统计的问题,我们这里以洛谷上的P3806 【模板】点分治1作为例题。

       题目大意:有一棵树,m个询问,每次询问树上有没有两个点之间的路径长度为k。

       根据树上点分治这个名字,那么我们自然就要分治啦,怎么分?题目只给了一棵树,于是我们就分这颗树,树上点分治的流程就是:每次选当前树上的一个点作为根,将这个点与它的子树分开,然后对子树也进行这样的操作(分),最后利用子树来统计答案(治)。就像这样:

       那么,第一个重点来了!如何拆这颗树,这是个很严肃的问题,分的方法决定了分治的效率,不妨设想一下,假如树退化成一条链(有n个节点),我们每次取链首作为根,那么这样分治要分n层,因为点分治在每一层所用的时间大约为O(n),那这样分治的时间复杂度就是O(n^2), 稳稳地超时(绝望.jpg),那么找哪个点作为根才能让时间复杂度变得更优秀呢?显然的,找树的重心是最优的,树的重心的定义是,树中所有节点中最大的子树最小的那个节点(百度百科),那么每次选树的重心作为根,拆出来的子树的大小就会尽可能平均,那分治的层数也就尽可能的小了(如果还是不理解的同学可以联系二分来思考一下,二分查找的时候是看左右端点的中间点,这样可以把序列分成平均的两份,如果是随便找一个点,那么查找的次数自然就会增多,效率也就低了),这样可以保证分治的层数大约是logn,那总的时间复杂度就是O(nlogn)了,十分的优秀(记笔记记笔记)。

       找树的重心我们只需要将整棵树过一遍即可,先贴上代码,对着代码讲方便各位理解。

void getroot(int x,int fa)//fa表示x的父亲,防止往回搜
{
	size[x]=1;mson[x]=0;//size[i]记录以i为根的子树的大小,mson[i]表示i节点的最大子树的大小 
	for(int i=first[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(v[y]||y==fa)continue;//v[y]表示当前节点是否被分治过,先别管这货,后面再讲 
		getroot(y,x);//往下继续搜索 
		size[x]+=size[y];//加上子树大小 
		if(size[y]>mson[x])mson[x]=size[y];//更新最大的子树 
	}
	if(Size-size[x]>mson[x])mson[x]=Size-size[x];
	//Size表示当前这整棵树的大小,那么Size-size[x]就表示不在x的子树内的节点数量,下面详解 
	if(ms>mson[x])ms=mson[x],root=x;//ms表示树的重心的最大子树的大小(相当于mson[root]),这一步是用来更新树的重心的,root用来记录重心 
}

       再解释一下这一句

if(Size-size[x]>mson[x])mson[x]=Size-size[x];

       先画一幅图, 方便理解。

       比如说当前到了红色节点处(也就是x为红色节点),那么size[x]就是绿圈里的点,Size-size[x]就是蓝圈里的点。

       如果以红色节点为根的话,其实蓝色圈里的节点也是红色节点的一棵子树,所以要算出这棵蓝圈里面的子树的大小,用它来更新一下mson(别忘记mson的定义呀)。

       那。。。我们继续?

       现在我们知道了如何拆树,也就是分的过程差不多解决了,那接下来的问题就是如何治!

       对于每一次找到的重心,我们统计与它有关的路径,也就是经过它的路径,怎么统计呢?先算出一个dis数组,dis[i]表示i这个点到目标点的距离(也可以说是目标点到这个点的距离),那么我们很容易就可以在O(n)的时间内求出dis数组。代码如下:

void getdis(int x,int fa,int z)//fa表示x的父亲,z表示x到目标点的距离
{
    dis[++t]=z;//这里写dis[++t]=z而不是dis[x]=z是因为后面我们会发现每一个t对应哪一个x并不重要,也就是说我们只需要知道每一个点到目标点的距离,而不需要知道那个点是谁,看到后面就会明白的啦
    for(int i=first[x];i;i=e[i].next)
    {
        int y=e[i].y;
        if(y==fa||v[y])continue;//这里加个v[y]是因为不能往被分治过的点走,后面会讲的啦
        getdis(y,x,z+e[i].z);
    }
}

       求出dis数组后,我们就可以将两两组合,得到一条条路径的长度,比如说,我们设目标点为重心,有一个点A到重心的距离为x,又有一个点B到重心的距离为y,那么A到B的距离就是x+y了!(树上两点间的路径唯一)

       但是!问题又来了,我们再看到下面这一幅图:

       我们设点1为重心。

       那么

       1—2的路径就是1—2,长度为1

       1—3的路径就是1—2—3,长度为2

       1—4的路径就是1—2—4,长度为2

       1—5的路径就是1—5,长度为1

       假如将1—2和1—5两条路径合并,得到2—1—5,也就是2到5的路径,长度为(1—2的长度)+(1—5的长度)=2,和上面说的一样。

       但问题就在于我如果把1—2—3和1—2—4这两条路径合并,得到的就是3—2—1—2—4这么一条路径,长度为4!什么鬼 ,这和想象中的不一样啊!3—4的距离不应该是2吗?我们发现,这条奇怪的路径走了1—2这条边两次,然而事实上1—2这条边是一次都不用走的,那么这种问题怎么解决呢?

       仔细想想,可以发现,假如我们将两个在同一棵子树内的点的dis组合起来,就会发生如上问题,就像上述例子,假如组合3、4,因为3、4都是在重心1的同一棵子树内,所以,我们称这种组合为不合法组合。那么我们一开始将dis数组两两组合得到的所有组合称为所有组合,那答案就显而易见了!合法组合=所有组合-不合法组合,历经千辛万苦,终于得到答案啦,那现在就看看代码吧!

void fenzhi(int x)
{
    v[x]=true;//代码保证每次进来的x都必定是当前这棵树的重心,我们将v[x]标记为true,表示x点被分治过
    solve(x,0,1);//计算这棵树以x为重心的所有组合,solve函数后面会讲
    for(int i=first[x];i;i=e[i].next)
    {
        int y=e[i].y;
        if(v[y])continue;
        solve(y,e[i].z,0);//计算不合法组合,用所有组合减去不合法组合
        ms=inf;root=0;Size=size[y];//记得要初始化
        getroot(y,0);//求出以y为根的子树
        fenzhi(root);
    }
}

接下来附上solve函数

void solve(int x,int y,int id)//x表示getdis的起始点,y表示x到目标点的距离,id表示这一次统计出来的答案是合理的还是不合理的
{
    t=0;
    getdis(x,0,y);//算出树中的点到目标点的距离
    if(id==1)//累计答案 
    {
        for(int i=1;i<t;i++)
        for(int j=i+1;j<=t;j++)
        sum[dis[i]+dis[j]]++;//sum[i]表示有多少条长度为i的路径,结合这一段就可以理解求dis数组时的dis[++t]=z了
    }
    else//去掉不合理答案 
    {
        for(int i=1;i<t;i++)
        for(int j=i+1;j<=t;j++)
        sum[dis[i]+dis[j]]--;
    }
}

       先别走!再等我多唠叨几句!

       我觉得我还需要对代码中的一个地方讲解一下

       那就是fenzhi函数中去除不合理组合的那一句话:solve(y,e[i].z,0);

       我们仔细思考一下这句话,solve(y,e[i].z,0),进入到solve函数时,首先做什么?getdis!没错,我们会以x点(也就是y的父亲)作为目标点计算出以y为根的子树中所有点的dis。等等!为什么不是以y作为目标点?仔细看看,solve里面y后面的那个参数,是e[i].z,而不是0,这意味着dis[y]=e[i].z。那么以y为根的子树中的所有点的dis值其实就等于他们到y点的距离+e[i].z,又因为e[i].z是y点到x点的距离,所以,其实这次计算出来的dis值就是以y为根的子树中所有点到x点的距离!

完整代码在此(只有99行100行都不到呢!很精简有木有)

#include <cstdio>
#include <cstring>
#define inf 999999999

int n,m,len=0,Size;
struct node{int x,y,z,next;};
node e[20010];
int first[10010];
int root,ms,size[10010],mson[10010],sum[10000010];
bool v[10010];
void buildroad(int x,int y,int z)
{
    len++;
    e[len].x=x;
    e[len].y=y;
    e[len].z=z;
    e[len].next=first[x];
    first[x]=len;
}
void getroot(int x,int fa)//fa表示x的父亲,防止往回搜
{
	size[x]=1;mson[x]=0;//size[i]记录以i为根的子树的大小,mson[i]表示i节点的最大子树的大小 
	for(int i=first[x];i;i=e[i].next)
	{
		int y=e[i].y;
		if(v[y]||y==fa)continue;//v[y]表示当前节点是否被分治过,先别管这货,后面再讲 
		getroot(y,x);//往下继续搜索 
		size[x]+=size[y];//加上子树大小 
		if(size[y]>mson[x])mson[x]=size[y];//更新最大的子树 
	}
	if(Size-size[x]>mson[x])mson[x]=Size-size[x];
	//Size表示当前这整棵树的大小,那么Size-size[x]就表示不在x的子树内的节点数量,下面详解 
	if(ms>mson[x])ms=mson[x],root=x;//ms表示树的重心的最大子树的大小(相当于mson[root]),这一步是用来更新树的重心的,root用来记录重心 
}
int t;
int dis[10010];
void getdis(int x,int fa,int z)//fa表示x的父亲,z表示x到目标点的距离
{
    dis[++t]=z;//这里写dis[++t]=z而不是dis[x]=z是因为后面我们会发现每一个t对应哪一个x并不重要,也就是说我们只需要知道每一个点到目标点的距离,而不需要知道那个点是谁,看到后面就会明白的啦
    for(int i=first[x];i;i=e[i].next)
    {
        int y=e[i].y;
        if(y==fa||v[y])continue;//这里加个v[y]是因为不能往被分治过的点走,后面会讲的啦
        getdis(y,x,z+e[i].z);
    }
}
void solve(int x,int y,int id)//x表示getdis的起始点,y表示x到目标点的距离,id表示这一次统计出来的答案是合理的还是不合理的
{
    t=0;
    getdis(x,0,y);//算出树中的点到目标点的距离
    if(id==1)//累计答案 
    {
        for(int i=1;i<t;i++)
        for(int j=i+1;j<=t;j++)
        sum[dis[i]+dis[j]]++;//sum[i]表示有多少条长度为i的路径,结合这一段就可以理解求dis数组时的dis[++t]=z了
    }
    else//去掉不合理答案 
    {
        for(int i=1;i<t;i++)
        for(int j=i+1;j<=t;j++)
        sum[dis[i]+dis[j]]--;
    }
}
void fenzhi(int x)
{
    v[x]=true;//代码保证每次进来的x都必定是当前这棵树的重心,我们将v[x]标记为true,表示x点被分治过
    solve(x,0,1);//计算这棵树以x为重心的所有组合,solve函数后面会讲
    for(int i=first[x];i;i=e[i].next)
    {
        int y=e[i].y;
        if(v[y])continue;
        solve(y,e[i].z,0);//计算不合法组合,用所有组合减去不合法组合
        ms=inf;root=0;Size=size[y];//记得要初始化
        getroot(y,0);//求出以y为根的子树
        fenzhi(root);
    }
}

int main()
{
    scanf("%d %d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int x,y,z;
        scanf("%d %d %d",&x,&y,&z);
        buildroad(x,y,z);
        buildroad(y,x,z);
    }
    root=0;ms=inf;Size=n;
    getroot(1,0);
    fenzhi(root);
    for(int i=1;i<=m;i++)
    {
        int x;
        scanf("%d",&x);
        if(sum[x])printf("AYE\n");
        else printf("NAY\n");
    }
}

感谢各位的阅读qwq

猜你喜欢

转载自blog.csdn.net/a_forever_dream/article/details/81778649