树链剖分算法详解

学OI也有一段时间了,感觉该搞点东西了。

于是学习了树()链()剖(pou)分(粪)

当然,学习这个算法是需要先学习线段树的。不懂的还是再过一段时间吧。


如果碰到一道题,要对一颗树的两个点中的最短路径、以u为根的子树之类的东西进行修改或者查询,那么大概就是树链剖分的题了。

树链剖分就是把一颗树的节点按照新的顺序扔到一颗线段树里面,然后保证一条树链上的点在线段树中尽可能连续。

为什么是尽可能?因为在一棵树中,怎么搞也无法保证对于每一个节点,他的父亲编号都是它的-1,所以是尽可能。那么怎么尽可能呢?

有很多算法,今天提到的就是树链剖分。我们把一颗树上的所有链分成轻链重链,然后就可以对于每一段连续的重链进行线段树上的修改了。

而划分轻链和重链的依据是:对于每一个节点u,v是它的儿子,v有一个大小,就是size,代表以v为根的子树的大小。我们选取u最大的儿子为重(zhong)儿子,其余儿子为轻儿子。以连向重儿子的边为重边,剩下的边为轻边。

然后所有重边连成的链叫做重链,(并不存在轻链)比如下图,红色的链是重链(注意,对于一个叶子节点,如果连向它的是一条轻链,那么他自己就是一条重链)

这样,我们把一棵树划分成了重链和轻链,我们能保证所有重链都不重不漏的包含了所有的点。

那么这些重链有什么用?在划分重链的过程中用到的DFS,这个DFS能保证,对于每一条重链,他们的DFS序是连续的!

这样,我们就可以用线段树(或者其他数据结构)维护了!

 现在,我们把熟练剖分化成两个部分:

1、把树上的所有点划分重链,然后求出它们的DFS序,以这个顺序扔到线段树里面。

2、在线段树上进行维护。

所以,如何实现划分重链?我们需要用两个DFS,第一个DFS找到所有点的重儿子,第二个DFS将所有重儿子连成重链。

第一个DFS:size是以当前点为根的子树的大小,f是当前点的父亲,son是当前点的重儿子。

inline void getson(int u,int fa){//获取每个节点的重儿子 
    size[u]=1;
    for(int e=head[u];e;e=nxt[e])
        if(to[e]!=fa){
            depth[to[e]]=depth[u]+1;
            f[to[e]]=u; 
            getson(to[e],u);
            size[u]+=size[to[e]];//记录以每个节点为根的树的大小 
            if(!son[u] || size[son[u]]<size[to[e]])  son[u]=to[e];//判断后将这个点变为重儿子 
        }
    return ;
}

第二个DFS:

inline void getdfn(int u,int t){//连成重链,其中我们可以保证,对于每一条重链,它们的dfn值是连续的。t记录的是当前链的链首 
    top[u]=t;//top记录当前链链首 
    dfn[u]=++cnt;//记录dfn值,也是在线段树中的位置 
    link[cnt]=u;//dfn的逆运算,用于建树时的初始赋值 
    if(!son[u])  return ;//如果当前点没有重儿子,说明是这条重链的结束。 
    getdfn(son[u],t);//继续走这条重链 
    for(int e=head[u];e;e=nxt[e])//这个相当于走每一条轻链 
        if(to[e]!=son[u] && to[e]!=f[u])
            getdfn(to[e],to[e]);//重新开始走每一条重链 
    return ;
}

然后,对于线段树的建树,是独立的,我们不用考虑链的关系。(input是输入文件)

inline void build(int i,int l,int r){//平凡的建树 
    tree[i].l=l,tree[i].r=r;
    if(l==r){
        tree[i].sum=input[link[l]]%mod;//link的作用 
        return ;
    }
    int mid=(l+r)>>1;
    build(i<<1,l,mid);
    build(i<<1|1,mid+1,r);
    tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%mod;
    return ;
}

最后是修改,查询和修改很像,一起说了。

我们要把u到v路径上所有的点都+k,那么我们就把u,v中深的那个,它到它所在重边的顶端+k。

然后跳过一条轻边,重复上面的步骤,知道u,v到一条重边上。

最后把u到v,+k就可以了。

inline void treeadd(int x,int y,int z){//将题中对树的修改转化成对线段树的修改 
    int tx=top[x],ty=top[y]; 
    while(tx!=ty){//如果两个点不在一条重链上 
        if(depth[tx]<depth[ty])  swap(x,y),swap(tx,ty);//保证x的重链首元素在下方 
        add(1,dfn[tx],dfn[x],z);//从x一直修改到x所在重链的收元素,因为他们在一条重链中,所以在线段树中的位置是连续的。 
        x=f[tx];//走过一条轻链,到上面一个重链的末尾 
        tx=top[x],ty=top[y];//分别更新x、y的重链顶端,准备下一次更新 
    }
    if(depth[x]<depth[y])  swap(x,y);//现在x、y都到了一条重链上了,然后要保证x在下面。 
    add(1,dfn[y],dfn[x],z);//再只用更新他们所在的链就可以了。 
    return ;
}
inline int treesum(int x,int y){//将题中对树查询得指令改为对线段树的查询。 
    int ans=0;
    int tx=top[x],ty=top[y];
    while(tx!=ty){//这一段和修改几乎一样,就是把原本对每一个区间的修改,变为了查询,其实都一样。 
        if(depth[tx]<depth[ty])  swap(tx,ty),swap(x,y);
        ans=(ans+query(1,dfn[tx],dfn[x]))%mod;
        x=f[tx];
        tx=top[x],ty=top[ty];
    }
    if(depth[x]<depth[y])  swap(x,y);
    return (ans+query(1,dfn[y],dfn[x]))%mod;
}

对于线段树上的维护,和朴素的线段树一样,就不多说了。

如果题目中说要将以i为根的子树+k,那就直接在线段树上从dfn[i]到dfn[i]+size[i],+k就可以了。

具体看AC代码:(洛谷模板题)

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#define in(a) a=read()
#define REP(i,k,n)  for(int i=k;i<=n;i++)
#define MAXN 100010
using namespace std;
inline int read(){
    int x=0,f=1;
    char ch=getchar();
    for(;!isdigit(ch);ch=getchar())
        if(ch=='-')
            f=-1;
    for(;isdigit(ch);ch=getchar())
        x=x*10+ch-'0';
    return x*f;
}
int n,m,r,mod,input[MAXN];
int total,head[MAXN],to[MAXN<<1],nxt[MAXN<<1];
int size[MAXN],depth[MAXN],f[MAXN],son[MAXN];
int cnt,dfn[MAXN],link[MAXN],top[MAXN];
struct node{
    int l,r,sum,lt;
}tree[MAXN<<2];
inline void adl(int a,int b){
    total++;
    to[total]=b;
    nxt[total]=head[a];
    head[a]=total;
    return ;
}
inline void getson(int u,int fa){//获取每个节点的重儿子 
    size[u]=1;
    for(int e=head[u];e;e=nxt[e])
        if(to[e]!=fa){
            depth[to[e]]=depth[u]+1;
            f[to[e]]=u; 
            getson(to[e],u);
            size[u]+=size[to[e]];//记录以每个节点为根的树的大小 
            if(!son[u] || size[son[u]]<size[to[e]])  son[u]=to[e];//判断后将这个点变为重儿子 
        }
    return ;
}
inline void getdfn(int u,int t){//连成重链,其中我们可以保证,对于每一条重链,它们的dfn值是连续的。t记录的是当前链的链首 
    top[u]=t;//top记录当前链链首 
    dfn[u]=++cnt;//记录dfn值,也是在线段树中的位置 
    link[cnt]=u;//dfn的逆运算,用于建树时的初始赋值 
    if(!son[u])  return ;//如果当前点没有重儿子,说明是这条重链的结束。 
    getdfn(son[u],t);//继续走这条重链 
    for(int e=head[u];e;e=nxt[e])//这个相当于走每一条轻链 
        if(to[e]!=son[u] && to[e]!=f[u])
            getdfn(to[e],to[e]);//重新开始走每一条重链 
    return ;
}
inline void build(int i,int l,int r){//平凡的建树 
    tree[i].l=l,tree[i].r=r;
    if(l==r){
        tree[i].sum=input[link[l]]%mod;//link的作用 
        return ;
    }
    int mid=(l+r)>>1;
    build(i<<1,l,mid);
    build(i<<1|1,mid+1,r);
    tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%mod;
    return ;
}
inline void pushdown(int i){//平凡的pushdown 
    if(!tree[i].lt)  return ;
    tree[i<<1].lt+=tree[i].lt;
    tree[i<<1|1].lt+=tree[i].lt;
    int mid=(tree[i].l+tree[i].r)>>1;
    tree[i<<1].sum=(tree[i<<1].sum+(mid-tree[i].l+1)*tree[i].lt)%mod;
    tree[i<<1|1].sum=(tree[i<<1|1].sum+(tree[i].r-mid)*tree[i].lt)%mod;
    tree[i].lt=0;
    return ;
}
inline void add(int i,int l,int r,int k){//平凡的区间修改 
    if(tree[i].l>=l && tree[i].r<=r){
        tree[i].sum=(tree[i].sum+(tree[i].r-tree[i].l+1)*k)%mod;
        tree[i].lt+=k;
        return ;
    }
    pushdown(i);
    if(tree[i<<1].r>=l)  add(i<<1,l,r,k);
    if(tree[i<<1|1].l<=r) add(i<<1|1,l,r,k);
    tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%mod;
    return ;
}
inline int query(int i,int l,int r){//平凡的区间查询 
    if(tree[i].l>=l && tree[i].r<=r)  return tree[i].sum;
    int sum=0;
    pushdown(i);
    if(tree[i<<1].r>=l)  sum=(sum+query(i<<1,l,r))%mod;
    if(tree[i<<1|1].l<=r)  sum=(sum+query(i<<1|1,l,r))%mod;
    return sum;
}
inline void treeadd(int x,int y,int z){//将题中对树的修改转化成对线段树的修改 
    int tx=top[x],ty=top[y]; 
    while(tx!=ty){//如果两个点不在一条重链上 
        if(depth[tx]<depth[ty])  swap(x,y),swap(tx,ty);//保证x的重链首元素在下方 
        add(1,dfn[tx],dfn[x],z);//从x一直修改到x所在重链的收元素,因为他们在一条重链中,所以在线段树中的位置是连续的。 
        x=f[tx];//走过一条轻链,到上面一个重链的末尾 
        tx=top[x],ty=top[y];//分别更新x、y的重链顶端,准备下一次更新 
    }
    if(depth[x]<depth[y])  swap(x,y);//现在x、y都到了一条重链上了,然后要保证x在下面。 
    add(1,dfn[y],dfn[x],z);//再只用更新他们所在的链就可以了。 
    return ;
}
inline int treesum(int x,int y){//将题中对树查询得指令改为对线段树的查询。 
    int ans=0;
    int tx=top[x],ty=top[y];
    while(tx!=ty){//这一段和修改几乎一样,就是把原本对每一个区间的修改,变为了查询,其实都一样。 
        if(depth[tx]<depth[ty])  swap(tx,ty),swap(x,y);
        ans=(ans+query(1,dfn[tx],dfn[x]))%mod;
        x=f[tx];
        tx=top[x],ty=top[ty];
    }
    if(depth[x]<depth[y])  swap(x,y);
    return (ans+query(1,dfn[y],dfn[x]))%mod;
}
int main(){
    in(n),in(m),in(r),in(mod);
    REP(i,1,n)  in(input[i]);
    int a,b;
    REP(i,1,n-1)  in(a),in(b),adl(a,b),adl(b,a);
    depth[r];
    getson(r,0);
    getdfn(r,r);
    build(1,1,n);
    int p,x,y,z;
    REP(i,1,m){
        in(p);
        if(p==1)  in(x),in(y),in(z),treeadd(x,y,z);
        if(p==2)  in(x),in(y),printf("%d\n",treesum(x,y));
        if(p==3)  in(x),in(z),add(1,dfn[x],dfn[x]+size[x]-1,z);//我们会发现,在树链剖分中,i这颗子树里面所有的节点的dfn都是连续的,我们修改u的子树就是将u到u+size-1修改就可以了。 
        if(p==4)  in(x),printf("%d\n",query(1,dfn[x],dfn[x]+size[x]-1));//查询同上。 
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/jason2003/p/9818242.html