树链剖分 完美的想法

树链剖分


不知是谁想出的想法,太完美了,首先我大致讲一下树剖的想法。

将树分成重链和轻链,使每条重链越长越好,每次可以用数据结构将重链上的所有节点求出或修改,达到优化的效果,下面我讲的是用线段树维护一棵树。

当然不止是线段树可以维护,树状数组和Splay也可以。

下面看一道题:


洛谷3384

题目描述

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

输入格式:

第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。
接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。
接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)
接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:
操作1: 1 x y z
操作2: 2 x y
操作3: 3 x z
操作4: 4 x

输出格式:

输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模

输入样例#1:

5 5 2 24
7 3 7 8 0 
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3

输出样例#1:

2
21

说明

时空限制:1s,128M

数据规模:

对于30%的数据: N \leq 10, M \leq 10N≤10,M≤10

对于70%的数据: N \leq {10}^3, M \leq {10}^3N≤103,M≤103

对于100%的数据: N \leq {10}^5, M \leq {10}^5N≤105,M≤105

其实,纯随机生成的树LCA+暴力是能过的,可是,你觉得可能是纯随机的么233

样例说明:

树的结构如下:

img

各个操作如下:

img

故输出应依次为2、21(重要的事情说三遍:记得取模)


树剖可以解决树上路径修改和子树修改,复杂度应该是 l o g 2 N

黑色边表示重链,当然重链是一条链,每个节点有且仅有一条重链与它的子节点相连。

先介绍变量:

int n,m,Rot,MOD,W[MAXN],a[MAXN];//Rot是根节点,MOD是要%的数
int Tre[MAXN<<2],Add[MAXN<<2];//线段树
struct Edge{//这是存边的
    int tot,lnk[MAXN],son[2*MAXN],nxt[2*MAXN];
    void Add(int x,int y){son[++tot]=y;nxt[tot]=lnk[x];lnk[x]=tot;}
}E;
int cnt,ID[MAXN];//存DFS序
int Dep[MAXN];//深度
int Son[MAXN];//重儿子,也就是重链连接的子节点
int Siz[MAXN];//子树大小
int Fa[MAXN];//父节点
int Top[MAXN];//这条重链顶端的节点,可以优化查找

预处理1:

标记重儿子节点,得到深度和子树大小。

void First(int x,int f,int D){
    Dep[x]=D,Siz[x]=1,Fa[x]=f;
    for(int j=E.lnk[x],NowSize=0;j;j=E.nxt[j])
    if(E.son[j]!=f){
        First(E.son[j],x,D+1);Siz[x]+=Siz[E.son[j]];
        if(Siz[E.son[j]]>NowSize) NowSize=Siz[E.son[j]],Son[x]=E.son[j];//我们要使重链越长越好,那么就选子树最大的连接。
    }
}

预处理2:

求出DFS序,用线段树维护。

我们以先根,然后重儿子,然后轻儿子的顺序DFS序,这样可以保证一条重链是连续的节点,方便更新重链。

void Second(int x,int top){
    ID[x]=++cnt,W[cnt]=a[x],Top[x]=top;
    if(!Son[x]) return;Second(Son[x],top);
    for(int j=E.lnk[x];j;j=E.nxt[j])
    if(!ID[E.son[j]]) Second(E.son[j],E.son[j]);
}

以下是区间修改线段树

void PushDown(int x,int Ln,int Rn){
    if(!Add[x]) return;
    Tre[x<<1]=(Tre[x<<1]+Ln*Add[x]%MOD)%MOD,Tre[x<<1|1]=(Tre[x<<1|1]+Rn*Add[x]%MOD)%MOD;
    Add[x<<1]=(Add[x<<1]+Add[x])%MOD,Add[x<<1|1]=(Add[x<<1|1]+Add[x])%MOD,Add[x]=0;
}
void Build(int x,int l,int r){
    if(l==r){Tre[x]=W[l]%MOD;return;}//W是DFS序过后的a数组         
    int mid=(r+l)>>1;
    Build(x<<1,l,mid);Build(x<<1|1,mid+1,r);
    Tre[x]=(Tre[x<<1]+Tre[x<<1|1])%MOD;
}
void Updata(int x,int l,int r,int L,int R,int p){
    if(L<=l&&r<=R){Tre[x]=(Tre[x]+(r-l+1)*p)%MOD;Add[x]=(Add[x]+p)%MOD;return;}
    int mid=(r+l)>>1;PushDown(x,(mid-l+1)%MOD,(r-mid)%MOD);
    if(L<=mid) Updata(x<<1,l,mid,L,R,p);
    if(R>mid) Updata(x<<1|1,mid+1,r,L,R,p);
    Tre[x]=(Tre[x<<1]+Tre[x<<1|1])%MOD;
}
int Query(int x,int l,int r,int L,int R){
    if(L<=l&&r<=R) return Tre[x]%MOD;
    int mid=(r+l)>>1;PushDown(x,(mid-l+1)%MOD,(r-mid)%MOD);
    int Ans=0;
    if(L<=mid) Ans+=Query(x<<1,l,mid,L,R);
    if(R>mid) Ans+=Query(x<<1|1,mid+1,r,L,R);
    return Ans%MOD;
}

将x这棵子树全部加上p

我们可以知道这棵子树所在的区间是ID[x]~ID[x]+Siz[x]-1,很好理解,因为这是DFS序后的。

Updata(1,1,n,ID[x],ID[x]+Siz[x]-1,z%MOD);

求x这棵子树

Query(1,1,n,ID[x],ID[x]+Siz[x]-1)

将这棵子树x到y的路径上都加上p

用类似LCA的想法,每次将重链上的所有点都加上p。

void Insert(int x,int y,int p){
    while(Top[x]!=Top[y]){
        if(Dep[Top[x]]>Dep[Top[y]]) swap(x,y);
        Updata(1,1,n,ID[Top[y]],ID[y],p);y=Fa[Top[y]];
    }
    if(Dep[x]>Dep[y]) swap(x,y);
    Updata(1,1,n,ID[x],ID[y],p);
}

求这棵子树x到y的路径的加和

还是一样

int AskSum(int x,int y){
    int ret=0;
    while(Top[x]!=Top[y]){
        if(Dep[Top[x]]>Dep[Top[y]]) swap(x,y);
        ret=(ret+Query(1,1,n,ID[Top[y]],ID[y]))%MOD;y=Fa[Top[y]];
    }
    if(Dep[x]>Dep[y]) swap(x,y);
    ret=(ret+Query(1,1,n,ID[x],ID[y]))%MOD;
    return ret;
}

下面贴上完整代码

#include<cstdio>
#include<cctype>
#include<iostream>
#include<algorithm>
#define MAXN 100005
using namespace std;
int n,m,Rot,MOD,W[MAXN],a[MAXN],Tre[MAXN<<2],Add[MAXN<<2];
struct Edge{
    int tot,lnk[MAXN],son[2*MAXN],nxt[2*MAXN];
    void Add(int x,int y){son[++tot]=y;nxt[tot]=lnk[x];lnk[x]=tot;}
}E;
int cnt,ID[MAXN],Dep[MAXN],Son[MAXN],Siz[MAXN],Fa[MAXN],Top[MAXN];
int read(){
    int ret=0;char ch=getchar();bool f=1;
    for(;!isdigit(ch);ch=getchar()) f^=!(ch^'-');
    for(; isdigit(ch);ch=getchar()) ret=(ret<<3)+(ret<<1)+ch-48;
    return f?ret:-ret;
}
void First(int x,int f,int D){
    Dep[x]=D,Siz[x]=1,Fa[x]=f;
    for(int j=E.lnk[x],NowSize=0;j;j=E.nxt[j])
    if(E.son[j]!=f){
        First(E.son[j],x,D+1);Siz[x]+=Siz[E.son[j]];
        if(Siz[E.son[j]]>NowSize) NowSize=Siz[E.son[j]],Son[x]=E.son[j];
    }
}
void Second(int x,int top){
    ID[x]=++cnt,W[cnt]=a[x],Top[x]=top;
    if(!Son[x]) return;Second(Son[x],top);
    for(int j=E.lnk[x];j;j=E.nxt[j])
    if(!ID[E.son[j]]) Second(E.son[j],E.son[j]);
}
void PushDown(int x,int Ln,int Rn){
    if(!Add[x]) return;
    Tre[x<<1]=(Tre[x<<1]+Ln*Add[x]%MOD)%MOD,Tre[x<<1|1]=(Tre[x<<1|1]+Rn*Add[x]%MOD)%MOD;
    Add[x<<1]=(Add[x<<1]+Add[x])%MOD,Add[x<<1|1]=(Add[x<<1|1]+Add[x])%MOD,Add[x]=0;
}
void Build(int x,int l,int r){
    if(l==r){Tre[x]=W[l]%MOD;return;}            
    int mid=(r+l)>>1;
    Build(x<<1,l,mid);Build(x<<1|1,mid+1,r);
    Tre[x]=(Tre[x<<1]+Tre[x<<1|1])%MOD;
}
void Updata(int x,int l,int r,int L,int R,int p){
    if(L<=l&&r<=R){Tre[x]=(Tre[x]+(r-l+1)*p)%MOD;Add[x]=(Add[x]+p)%MOD;return;}
    int mid=(r+l)>>1;PushDown(x,(mid-l+1)%MOD,(r-mid)%MOD);
    if(L<=mid) Updata(x<<1,l,mid,L,R,p);
    if(R>mid) Updata(x<<1|1,mid+1,r,L,R,p);
    Tre[x]=(Tre[x<<1]+Tre[x<<1|1])%MOD;
}
int Query(int x,int l,int r,int L,int R){
    if(L<=l&&r<=R) return Tre[x]%MOD;
    int mid=(r+l)>>1;PushDown(x,(mid-l+1)%MOD,(r-mid)%MOD);
    int Ans=0;
    if(L<=mid) Ans+=Query(x<<1,l,mid,L,R);
    if(R>mid) Ans+=Query(x<<1|1,mid+1,r,L,R);
    return Ans%MOD;
}
void Insert(int x,int y,int p){
    while(Top[x]!=Top[y]){
        if(Dep[Top[x]]>Dep[Top[y]]) swap(x,y);
        Updata(1,1,n,ID[Top[y]],ID[y],p);y=Fa[Top[y]];
    }
    if(Dep[x]>Dep[y]) swap(x,y);
    Updata(1,1,n,ID[x],ID[y],p);
}
int AskSum(int x,int y){
    int ret=0;
    while(Top[x]!=Top[y]){
        if(Dep[Top[x]]>Dep[Top[y]]) swap(x,y);
        ret=(ret+Query(1,1,n,ID[Top[y]],ID[y]))%MOD;y=Fa[Top[y]];
    }
    if(Dep[x]>Dep[y]) swap(x,y);
    ret=(ret+Query(1,1,n,ID[x],ID[y]))%MOD;
    return ret;
}
int main(){
    #ifndef ONLINE_JUDGE
    freopen("prob.in","r",stdin);
    freopen("prob.out","w",stdout);
    #endif
    n=read();m=read();Rot=read();MOD=read();
    for(int i=1;i<=n;i++) a[i]=read(),a[i]%=MOD;
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        E.Add(x,y),E.Add(y,x);
    }
    First(Rot,0,1);Second(Rot,Rot);Build(1,1,n);
    for(int i=1;i<=m;i++){
        int opt=read();
        if(opt==1){
            int x=read(),y=read(),z=read();
            Insert(x,y,z%MOD);
        }else
        if(opt==2){
            int x=read(),y=read();
            printf("%d\n",AskSum(x,y));
        }else
        if(opt==3){
            int x=read(),z=read();
            Updata(1,1,n,ID[x],ID[x]+Siz[x]-1,z%MOD);
        }else{
            int x=read();
            printf("%d\n",Query(1,1,n,ID[x],ID[x]+Siz[x]-1));
        }
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_41357771/article/details/80586456