树链剖分——入门及模板题

一篇很好的博客 https://www.cnblogs.com/ivanovcraft/p/9019090.html

树链剖分:将树分割成一条条链,然后按照dfs序进行维护

为什么要进行树链剖分?
首先来说一般的dfs序:可以想象普通的dfs序只能保证同一子树的结点序号是连续的
但是这样的dfs并没有很好的性质,比如我们要处理从root点到某个叶子结点上的路径,普通的dfs序就无法维护这样的路径
例如root有三个子节点a,b,c,子节点a有两个儿子d,e 子节点b有两个儿子f,g
现在我们要访问指定的一条路径root->d,但是dfs时root可能先访问的是b,然后再访问a
这样访问路径就必须通过暴力lca来进行dfs序的询问,而暴力lca的复杂度会达到O(n)

那么我们要使用一种方法来优化向上爬的复杂度
我们考虑dfs时先dfs到size最大的点那里——称为重儿子,dfs到底后这条链就是重链
不被重链包含的边称为轻边
然后再去dfs其他子节点,dfs的过程和根节点的过程一样
那么树的dfs序就可以看成是很多条不相交的重链,被轻链连在一起

这样的树链有两个性质:
轻边(u,v) 那么size[u]/2>size[v],且轻边
从根节点到任意结点的路径上的轻重链数量是logn级的
感性体会一下这两个性质:要经过轻边(u,v),说明v的结点个数不会超过size[u]/2
那么每次出现一条轻边,其下面的子树规模都要除以2,所以到人以结点,轻边的数量是logn级的
那么可知轻重链的数量也是logn级的

有了这两个性质,我们再来看向上爬的问题,非根结点x到根节点的路径上最多交替出现logn条重链
而重链的dfs序恰好是连在一起的,那么就可以一起维护了
所以这样来看从x到根节点只要维护logn次即可

对于树上任意两点(x,y),只要维护两条到根节点的路径即可!

第一次dfs处理出size数组,fa数组,重儿子数组son,deep数组(在向上爬的时候要用到,这里顺便处理了)
第二次dfs处理出轻重链:idx数组(即dfs序,最好再处理个反向的has数组,建立线段树时要用),top数组(维护每个点所在链的顶端)

树链剖分最简单的应用就是求lca啦!

一些数组的定义

int f[maxn]; //父亲 
int d[maxn]; //深度 
int son[maxn]; //重儿子 
int size[maxn]; //大小 
int top[maxn]; //所在链的顶端 
int id[maxn]; //dfs序 
int rk[maxn]; //dfs序对应的结点编号 
int cnt; //dfs的时间戳 

轻重链剖分

void dfs1(int x,int pre,int deep){
    size[x]=1,d[x]=deep;
    for(int i=head[x];i!=-1;i=edge[i].nxt){
        int y=edge[i].to;
        if(y==pre)continue;
        f[y]=x;dfs1(y,x,deep+1);size[x]+=size[y];
        if(size[son[x]]<size[y])son[x]=y;
    }
}
void dfs2(int x,int tp){    
    top[x]=tp;id[x]=++cnt;rk[cnt]=x;
    if(son[x])dfs2(son[x],tp);
    for(int i=head[x];i!=-1;i=edge[i].nxt){
        int y=edge[i].to;
        if(y!=son[x] && y!=f[x])dfs2(y,y);
    }
}

更新任意一条链(x,y):加上某个值

void updates(int x,int y,int c){
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]])swap(x,y);
        update(id[top[x]],id[x],1,n,1,c);
        x=f[top[x]];
    }
    if(id[x]>id[y])swap(x,y);
    update(id[x],id[y],1,n,1,c);
}

询问任意一条链(x,y):求和

ll Sum(int x,int y){
    ll res=0;
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]])swap(x,y); 
        res=(res+query(id[top[x]],id[x],1,n,1))%mod;
        x=f[top[x]];
    }
    if(id[x]>id[y])swap(x,y);
    return (res+query(id[x],id[y],1,n,1))%mod; 
}

然后是完整的代码 https://www.luogu.org/problemnew/show/P3384

#include<bits/stdc++.h>
using namespace std;
#define maxn 200006
#define ll long long
struct Edge{int to,nxt;}edge[maxn<<1];
ll n,m,head[maxn],tot,r,v[maxn],mod;
void init(){memset(head,-1,sizeof head);tot=0;}
void addedge(int u,int v){edge[tot].nxt=head[u],edge[tot].to=v;head[u]=tot++;}

int f[maxn];//父亲 
int d[maxn];//深度 
int son[maxn];//重儿子 
int size[maxn];//大小 
int top[maxn];//所在链的顶端 
int id[maxn];//dfs序 
int rk[maxn];//dfs序对应的结点编号 
int cnt;//dfs的时间戳 
 
void dfs1(int x,int pre,int deep){
    size[x]=1,d[x]=deep;
    for(int i=head[x];i!=-1;i=edge[i].nxt){
        int y=edge[i].to;
        if(y==pre)continue;
        f[y]=x;dfs1(y,x,deep+1);size[x]+=size[y];
        if(size[son[x]]<size[y])son[x]=y;
    }
}
void dfs2(int x,int tp){    
    top[x]=tp;id[x]=++cnt;rk[cnt]=x;
    if(son[x])dfs2(son[x],tp);
    for(int i=head[x];i!=-1;i=edge[i].nxt){
        int y=edge[i].to;
        if(y!=son[x] && y!=f[x])dfs2(y,y);
    }
}
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
ll sum[maxn<<2],lazy[maxn<<2];
void pushup(int rt){sum[rt]=sum[rt<<1]+sum[rt<<1|1];}
void pushdown(int l,int r,int rt){
    if(lazy[rt]){
        int m=l+r>>1;
        sum[rt<<1]=(sum[rt<<1]+lazy[rt]*(m-l+1)%mod)%mod;
        sum[rt<<1|1]=(sum[rt<<1|1]+lazy[rt]*(r-m)%mod)%mod;
        lazy[rt<<1]=(lazy[rt<<1]+lazy[rt])%mod;
        lazy[rt<<1|1]=(lazy[rt<<1|1]+lazy[rt])%mod;
        lazy[rt]=0;
    }
}
void build(int l,int r,int rt){
    if(l==r){sum[rt]=v[rk[l]];return;}
    int m=l+r>>1;
    build(lson);build(rson);
    pushup(rt);
}
void update(int L,int R,int l,int r,int rt,ll val){
    if(L<=l && R>=r){
        lazy[rt]=(lazy[rt]+val)%mod;
        sum[rt]=(sum[rt]+val*(r-l+1))%mod;
        return;
    } 
    pushdown(l,r,rt);
    int m=l+r>>1;
    if(L<=m)update(L,R,lson,val);
    if(R>m)update(L,R,rson,val);
    pushup(rt);
}
ll query(int L,int R,int l,int r,int rt){
    if(L<=l && R>=r)return sum[rt];
    pushdown(l,r,rt);
    int m=l+r>>1;ll res=0; 
    if(L<=m)res=(res+query(L,R,lson))%mod;
    if(R>m)res=(res+query(L,R,rson))%mod;
    return res;
    
}
ll Sum(int x,int y){
    ll res=0;
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]])swap(x,y); 
        res=(res+query(id[top[x]],id[x],1,n,1))%mod;
        x=f[top[x]];
    }
    if(id[x]>id[y])swap(x,y);
    return (res+query(id[x],id[y],1,n,1))%mod; 
}
void updates(int x,int y,int c){
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]])swap(x,y);
        update(id[top[x]],id[x],1,n,1,c);
        x=f[top[x]];
    }
    if(id[x]>id[y])swap(x,y);
    update(id[x],id[y],1,n,1,c);
}

int main(){
    init();
    scanf("%lld%lld%lld%lld",&n,&m,&r,&mod); 
    for(int i=1;i<=n;i++)scanf("%lld",&v[i]);
    ll x,y;
    for(int j=1;j<n;j++){
        scanf("%lld%lld",&x,&y);
        addedge(x,y);addedge(y,x);
    }
    cnt=0;dfs1(r,0,1);dfs2(r,r);
    build(1,n,1);
    while(m--){
        ll op,x,y,k;
        scanf("%lld",&op);
        if(op==1){//x->y路径上+k
            scanf("%lld%lld%lld",&x,&y,&k); 
            updates(x,y,k);
        }
        else if(op==2){//x->y上的sum 
            scanf("%lld%lld",&x,&y);
            cout<<Sum(x,y)<<endl; 
        }
        else if(op==3){//x子树上的所有点都+k
            scanf("%lld%lld",&x,&k);
            update(id[x],id[x]+size[x]-1,1,n,1,k); 
        }
        else {//查询x子树下的和 
            scanf("%lld",&x);
            cout<<query(id[x],id[x]+size[x]-1,1,n,1)<<endl;
        }
    }
}
View Code

猜你喜欢

转载自www.cnblogs.com/zsben991126/p/10754909.html