前言
树链剖分,将树的边划分为很多条链,由此降低对树上修改查询等的复杂度。
本次介绍轻重链剖分。
概念:
重儿子:子树的节点最多的儿子,其中如果两个儿子的子树都相同,那么其中任意个。
轻儿子:其余的儿子。
重边:父亲到重儿子的边。
轻边:其余的边。
重链:节点到重儿子的路径。
原理
模板传送门luoguP3384
分析:
要对某一段链或子树进行修改和查询,容易想到使用线段树维护,那么我们应该如何将树进行分解,使得能在放到线段树上进行修改查询。这就要用到树链剖分。
树链剖分
我们通过2次DFS来分解。
第一次DFS
用dep[]表示节点的深度,far[]表示节点的父节点,size[]表示以x为根的子树节点个数,son[]表示x的重儿子。
代码实现:
dep[]好实现,每次DFS时深度+1即可。
far[]也好实现,每次DFS时记下它的父亲,将当前节点和父亲连上即可。
size[]每次初始化为1,DFS后回溯回来时将左右儿子的子树个数累加即可。
递推方程: s i z e [ x ] = s i z e [ l ] + s i z e [ r ] size[x]=size[l]+size[r] size[x]=size[l]+size[r]
son[]类似于点分治找树的重心,由于每个节点的子节点至多有2个,所以我们只需要比较这2个节点谁的size[]大,谁就是重儿子。
代码:
inline void dfs1(int x,int fa,int deep){
dep[x]=deep;
far[x]=fa;
size[x]=1;
int Max=-1;
for(int i=first[x];i;i=nex[i]){
int y=to[i];
if(y==fa) continue;
dfs1(y,x,deep+1);
size[x]+=size[y];
if(size[x]>Max) son[x]=y,Max=size[y];
}
}
第二次DFS
由于我们分解过后使得点对应的编号不连续,我们要使用新的编号来存储。
我们用id[]表示节点x对应的新编号,wt[]表示该节点对应点权值,top[]表示该条链上最顶端的节点。
注意: 每条链都是从轻儿子开始。
代码实现:
id[],wt[]每次dfs更新即可。
top[]由于每条链从轻儿子开始,所以对于轻儿子我们才更新顶端值。
我们发现:从轻儿子出发一直是连重儿子到底,这样就能保证对于一条链上的编号的连续,因为要维护线段树。
所以我们DFS时要先找重儿子,然后再找轻儿子。
代码:
inline void dfs2(int x,int topf){
id[x]=++cnt;
wt[cnt]=w[x];
top[x]=topf;
if(son[x]==0) return;
dfs2(son[x],topf);
for(int i=first[x];i;i=nex[i]){
int y=to[i];
if(y==far[x] || y==son[x]) continue;
dfs2(y,y);
}
}
线段树维护
维护链
对于将某2个节点之间的路径修改或者查询,我们发现:
如果2个节点已经在一条重链上了,因为编号连续,直接在线段树更新即可。
如果2个节点不在一条重链上,即它们的最顶端不相同,我们可以先修改某一链,然后再向跳,看是否在一条重链上,如果在执行1,否则重复。
因为要向上跳,所以我们要从深度更大的点向上。对于同一链,编号也要从小到大,用swap交换即可。
查询操作与之类似。
inline void update1(int x,int y,int val){
val%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
T.update(1,id[top[x]],id[x],val);
x=far[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
T.update(1,id[x],id[y],val);
}
inline int query1(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=T.query(1,id[top[x]],id[x])%mod;
x=far[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=(ans+T.query(1,id[x],id[y]))%mod;
return ans;
}
维护子树
因为DFS,所以一棵子树的全部节点是一定连续。
那么最小编号就是这个根,最大编号就根的编号加上子树的个数-1。
inline void update2(int x,int val){
T.update(1,id[x],id[x]+size[x]-1,val);
}
inline int query2(int x){
return T.query(1,id[x],id[x]+size[x]-1)%mod;
}
线段树
线段树就是标准的区间更新,区间查询,要使用lazy_tag
struct TREE{
struct node{
int l,r,w,lz;
}t[N<<2];
inline void pushdown(int k){
if(t[k].lz){
t[lc].lz+=t[k].lz;
t[rc].lz+=t[k].lz;
t[lc].w=(t[lc].w+t[k].lz*(t[lc].r-t[lc].l+1))%mod;
t[rc].w=(t[rc].w+t[k].lz*(t[rc].r-t[rc].l+1))%mod;
t[k].lz=0;
}
}
inline void build(int k,int l,int r){
t[k].l=l,t[k].r=r;
if(l==r){
t[k].w=wt[l]%mod;
return;
}
int mid=(l+r)>>1;
build(lc,l,mid);
build(rc,mid+1,r);
t[k].w=(t[lc].w+t[rc].w)%mod;
}
inline void update(int k,int l,int r,int val){
if(t[k].l>=l && t[k].r<=r){
t[k].lz=(t[k].lz+val)%mod;
t[k].w=(t[k].w+val*(t[k].r-t[k].l+1)%mod)%mod;
return;
}
pushdown(k);
int mid=(t[k].l+t[k].r)>>1;
if(l<=mid) update(lc,l,r,val);
if(r>mid) update(rc,l,r,val);
t[k].w=(t[lc].w+t[rc].w)%mod;
}
inline int query(int k,int l,int r){
if(t[k].l>=l && t[k].r<=r) return t[k].w;
pushdown(k);
int mid=(t[k].l+t[k].r)>>1,sum=0;
if(l<=mid) sum=(sum+query(lc,l,r))%mod;
if(r>mid) sum=(sum+query(rc,l,r))%mod;
return sum;
}
}T;
完整代码
#include<bits/stdc++.h>
using namespace std;
#define lc k<<1
#define rc k<<1|1
const int N=1e5+5,M=2e5+5;
int n,m,rt,mod;
int first[N],nex[M],to[M],w[M],tot;
int son[N],id[N],far[N],cnt,dep[N],size[N],top[N],wt[N];
struct TREE{
struct node{
int l,r,w,lz;
}t[N<<2];
inline void pushdown(int k){
if(t[k].lz){
t[lc].lz+=t[k].lz;
t[rc].lz+=t[k].lz;
t[lc].w=(t[lc].w+t[k].lz*(t[lc].r-t[lc].l+1))%mod;
t[rc].w=(t[rc].w+t[k].lz*(t[rc].r-t[rc].l+1))%mod;
t[k].lz=0;
}
}
inline void build(int k,int l,int r){
t[k].l=l,t[k].r=r;
if(l==r){
t[k].w=wt[l]%mod;
return;
}
int mid=(l+r)>>1;
build(lc,l,mid);
build(rc,mid+1,r);
t[k].w=(t[lc].w+t[rc].w)%mod;
}
inline void update(int k,int l,int r,int val){
if(t[k].l>=l && t[k].r<=r){
t[k].lz=(t[k].lz+val)%mod;
t[k].w=(t[k].w+val*(t[k].r-t[k].l+1)%mod)%mod;
return;
}
pushdown(k);
int mid=(t[k].l+t[k].r)>>1;
if(l<=mid) update(lc,l,r,val);
if(r>mid) update(rc,l,r,val);
t[k].w=(t[lc].w+t[rc].w)%mod;
}
inline int query(int k,int l,int r){
if(t[k].l>=l && t[k].r<=r) return t[k].w;
pushdown(k);
int mid=(t[k].l+t[k].r)>>1,sum=0;
if(l<=mid) sum=(sum+query(lc,l,r))%mod;
if(r>mid) sum=(sum+query(rc,l,r))%mod;
return sum;
}
}T;
inline void add(int x,int y){
nex[++tot]=first[x];
first[x]=tot;
to[tot]=y;
}
inline void update1(int x,int y,int val){
val%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
T.update(1,id[top[x]],id[x],val);
x=far[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
T.update(1,id[x],id[y],val);
}
inline int query1(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=T.query(1,id[top[x]],id[x])%mod;
x=far[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=(ans+T.query(1,id[x],id[y]))%mod;
return ans;
}
inline void update2(int x,int val){
T.update(1,id[x],id[x]+size[x]-1,val);
}
inline int query2(int x){
return T.query(1,id[x],id[x]+size[x]-1)%mod;
}
inline void dfs1(int x,int fa,int deep){
dep[x]=deep;
far[x]=fa;
size[x]=1;
int Max=-1;
for(int i=first[x];i;i=nex[i]){
int y=to[i];
if(y==fa) continue;
dfs1(y,x,deep+1);
size[x]+=size[y];
if(size[x]>Max) son[x]=y,Max=size[y];
}
}
inline void dfs2(int x,int topf){
id[x]=++cnt;
wt[cnt]=w[x];
top[x]=topf;
if(son[x]==0) return;
dfs2(son[x],topf);
for(int i=first[x];i;i=nex[i]){
int y=to[i];
if(y==far[x] || y==son[x]) continue;
dfs2(y,y);
}
}
int main(){
scanf("%d%d%d%d",&n,&m,&rt,&mod);
int x,y,z;
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs1(rt,0,1);
dfs2(rt,rt);
T.build(1,1,n);
int cas;
while(m--){
scanf("%d",&cas);
if(cas==1){
scanf("%d%d%d",&x,&y,&z);
update1(x,y,z);
}
else if(cas==2){
scanf("%d%d",&x,&y);
cout<<query1(x,y)<<endl;
}
else if(cas==3){
scanf("%d%d",&x,&y);
update2(x,y);
}
else if(cas==4){
scanf("%d",&x);
cout<<query2(x)<<endl;
}
}
return 0;
}