学习笔记:树链剖分

前言

树链剖分,将树的边划分为很多条链,由此降低对树上修改查询等的复杂度。
本次介绍轻重链剖分。
概念:
重儿子:子树的节点最多的儿子,其中如果两个儿子的子树都相同,那么其中任意个。
轻儿子:其余的儿子。
重边:父亲到重儿子的边。
轻边:其余的边。
重链:节点到重儿子的路径。

原理

模板传送门luoguP3384
分析:
要对某一段链或子树进行修改和查询,容易想到使用线段树维护,那么我们应该如何将树进行分解,使得能在放到线段树上进行修改查询。这就要用到树链剖分。

树链剖分

我们通过2次DFS来分解。

第一次DFS

用dep[]表示节点的深度,far[]表示节点的父节点,size[]表示以x为根的子树节点个数,son[]表示x的重儿子。
代码实现:
dep[]好实现,每次DFS时深度+1即可。
far[]也好实现,每次DFS时记下它的父亲,将当前节点和父亲连上即可。
size[]每次初始化为1,DFS后回溯回来时将左右儿子的子树个数累加即可。
递推方程: s i z e [ x ] = s i z e [ l ] + s i z e [ r ] size[x]=size[l]+size[r] size[x]=size[l]+size[r]
son[]类似于点分治找树的重心,由于每个节点的子节点至多有2个,所以我们只需要比较这2个节点谁的size[]大,谁就是重儿子。
代码:

inline void dfs1(int x,int fa,int deep){
    
    
	dep[x]=deep;
	far[x]=fa;
	size[x]=1;
	int Max=-1;
	for(int i=first[x];i;i=nex[i]){
    
    
		int y=to[i];
		if(y==fa) continue;
		dfs1(y,x,deep+1);
		size[x]+=size[y];
		if(size[x]>Max) son[x]=y,Max=size[y];
	}
}

第二次DFS

由于我们分解过后使得点对应的编号不连续,我们要使用新的编号来存储。
我们用id[]表示节点x对应的新编号,wt[]表示该节点对应点权值,top[]表示该条链上最顶端的节点。
注意: 每条链都是从轻儿子开始。
代码实现:
id[],wt[]每次dfs更新即可。
top[]由于每条链从轻儿子开始,所以对于轻儿子我们才更新顶端值。
我们发现:从轻儿子出发一直是连重儿子到底,这样就能保证对于一条链上的编号的连续,因为要维护线段树。
所以我们DFS时要先找重儿子,然后再找轻儿子。
代码:

inline void dfs2(int x,int topf){
    
    
	id[x]=++cnt;
	wt[cnt]=w[x];
	top[x]=topf;
	if(son[x]==0) return;
	dfs2(son[x],topf);
	for(int i=first[x];i;i=nex[i]){
    
    
		int y=to[i];
		if(y==far[x] || y==son[x]) continue;
		dfs2(y,y);
	}
}

线段树维护

维护链

对于将某2个节点之间的路径修改或者查询,我们发现:
如果2个节点已经在一条重链上了,因为编号连续,直接在线段树更新即可。
如果2个节点不在一条重链上,即它们的最顶端不相同,我们可以先修改某一链,然后再向跳,看是否在一条重链上,如果在执行1,否则重复。
因为要向上跳,所以我们要从深度更大的点向上。对于同一链,编号也要从小到大,用swap交换即可。
查询操作与之类似。

inline void update1(int x,int y,int val){
    
    
	val%=mod;
	while(top[x]!=top[y]){
    
    
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		T.update(1,id[top[x]],id[x],val);
		x=far[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	T.update(1,id[x],id[y],val);
}

inline int query1(int x,int y){
    
    
	int ans=0;
	while(top[x]!=top[y]){
    
    
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		ans+=T.query(1,id[top[x]],id[x])%mod;
		x=far[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	ans=(ans+T.query(1,id[x],id[y]))%mod;
	return ans;
}

维护子树

因为DFS,所以一棵子树的全部节点是一定连续。
那么最小编号就是这个根,最大编号就根的编号加上子树的个数-1。

inline void update2(int x,int val){
    
    
	T.update(1,id[x],id[x]+size[x]-1,val);
}

inline int query2(int x){
    
    
	return T.query(1,id[x],id[x]+size[x]-1)%mod;
}

线段树

线段树就是标准的区间更新,区间查询,要使用lazy_tag

struct TREE{
    
    
	struct node{
    
    
		int l,r,w,lz;
	}t[N<<2];
	inline void pushdown(int k){
    
    
		if(t[k].lz){
    
    
			t[lc].lz+=t[k].lz;
			t[rc].lz+=t[k].lz;
			t[lc].w=(t[lc].w+t[k].lz*(t[lc].r-t[lc].l+1))%mod;
			t[rc].w=(t[rc].w+t[k].lz*(t[rc].r-t[rc].l+1))%mod;
			t[k].lz=0;
		}
	}
	inline void build(int k,int l,int r){
    
    
		t[k].l=l,t[k].r=r;
		if(l==r){
    
    
			t[k].w=wt[l]%mod;
			return;
		}
		int mid=(l+r)>>1;
		build(lc,l,mid);
		build(rc,mid+1,r);
		t[k].w=(t[lc].w+t[rc].w)%mod;
	}
	inline void update(int k,int l,int r,int val){
    
    
		if(t[k].l>=l && t[k].r<=r){
    
    
			t[k].lz=(t[k].lz+val)%mod;
			t[k].w=(t[k].w+val*(t[k].r-t[k].l+1)%mod)%mod;
			return;
		}
		pushdown(k);
		int mid=(t[k].l+t[k].r)>>1;
		if(l<=mid) update(lc,l,r,val);
		if(r>mid) update(rc,l,r,val);
		t[k].w=(t[lc].w+t[rc].w)%mod;
	}
	inline int query(int k,int l,int r){
    
    
		if(t[k].l>=l && t[k].r<=r) return t[k].w;
		pushdown(k);
		int mid=(t[k].l+t[k].r)>>1,sum=0;
		if(l<=mid) sum=(sum+query(lc,l,r))%mod;
		if(r>mid) sum=(sum+query(rc,l,r))%mod;
		return sum;
	}
}T;

完整代码

#include<bits/stdc++.h>
using namespace std;
#define lc k<<1
#define rc k<<1|1

const int N=1e5+5,M=2e5+5;
int n,m,rt,mod;
int first[N],nex[M],to[M],w[M],tot;
int son[N],id[N],far[N],cnt,dep[N],size[N],top[N],wt[N];

struct TREE{
    
    
	struct node{
    
    
		int l,r,w,lz;
	}t[N<<2];
	inline void pushdown(int k){
    
    
		if(t[k].lz){
    
    
			t[lc].lz+=t[k].lz;
			t[rc].lz+=t[k].lz;
			t[lc].w=(t[lc].w+t[k].lz*(t[lc].r-t[lc].l+1))%mod;
			t[rc].w=(t[rc].w+t[k].lz*(t[rc].r-t[rc].l+1))%mod;
			t[k].lz=0;
		}
	}
	inline void build(int k,int l,int r){
    
    
		t[k].l=l,t[k].r=r;
		if(l==r){
    
    
			t[k].w=wt[l]%mod;
			return;
		}
		int mid=(l+r)>>1;
		build(lc,l,mid);
		build(rc,mid+1,r);
		t[k].w=(t[lc].w+t[rc].w)%mod;
	}
	inline void update(int k,int l,int r,int val){
    
    
		if(t[k].l>=l && t[k].r<=r){
    
    
			t[k].lz=(t[k].lz+val)%mod;
			t[k].w=(t[k].w+val*(t[k].r-t[k].l+1)%mod)%mod;
			return;
		}
		pushdown(k);
		int mid=(t[k].l+t[k].r)>>1;
		if(l<=mid) update(lc,l,r,val);
		if(r>mid) update(rc,l,r,val);
		t[k].w=(t[lc].w+t[rc].w)%mod;
	}
	inline int query(int k,int l,int r){
    
    
		if(t[k].l>=l && t[k].r<=r) return t[k].w;
		pushdown(k);
		int mid=(t[k].l+t[k].r)>>1,sum=0;
		if(l<=mid) sum=(sum+query(lc,l,r))%mod;
		if(r>mid) sum=(sum+query(rc,l,r))%mod;
		return sum;
	}
}T;

inline void add(int x,int y){
    
    
	nex[++tot]=first[x];
	first[x]=tot;
	to[tot]=y;
}

inline void update1(int x,int y,int val){
    
    
	val%=mod;
	while(top[x]!=top[y]){
    
    
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		T.update(1,id[top[x]],id[x],val);
		x=far[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	T.update(1,id[x],id[y],val);
}

inline int query1(int x,int y){
    
    
	int ans=0;
	while(top[x]!=top[y]){
    
    
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		ans+=T.query(1,id[top[x]],id[x])%mod;
		x=far[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	ans=(ans+T.query(1,id[x],id[y]))%mod;
	return ans;
}

inline void update2(int x,int val){
    
    
	T.update(1,id[x],id[x]+size[x]-1,val);
}

inline int query2(int x){
    
    
	return T.query(1,id[x],id[x]+size[x]-1)%mod;
}

inline void dfs1(int x,int fa,int deep){
    
    
	dep[x]=deep;
	far[x]=fa;
	size[x]=1;
	int Max=-1;
	for(int i=first[x];i;i=nex[i]){
    
    
		int y=to[i];
		if(y==fa) continue;
		dfs1(y,x,deep+1);
		size[x]+=size[y];
		if(size[x]>Max) son[x]=y,Max=size[y];
	}
}

inline void dfs2(int x,int topf){
    
    
	id[x]=++cnt;
	wt[cnt]=w[x];
	top[x]=topf;
	if(son[x]==0) return;
	dfs2(son[x],topf);
	for(int i=first[x];i;i=nex[i]){
    
    
		int y=to[i];
		if(y==far[x] || y==son[x]) continue;
		dfs2(y,y);
	}
}

int main(){
    
    
	scanf("%d%d%d%d",&n,&m,&rt,&mod);
	int x,y,z;
	for(int i=1;i<=n;i++) scanf("%d",&w[i]);
	for(int i=1;i<n;i++){
    
    
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(rt,0,1);
	dfs2(rt,rt);
	T.build(1,1,n);
	int cas;
	while(m--){
    
    
		scanf("%d",&cas);
		if(cas==1){
    
    
			scanf("%d%d%d",&x,&y,&z);
			update1(x,y,z);
		}
		else if(cas==2){
    
    
			scanf("%d%d",&x,&y);
			cout<<query1(x,y)<<endl;
		}
		else if(cas==3){
    
    
			scanf("%d%d",&x,&y);
			update2(x,y);
		}
		else if(cas==4){
    
    
			scanf("%d",&x);
			cout<<query2(x)<<endl;
		}
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/pigonered/article/details/121256412
今日推荐