dsu on tree备忘

dsu on tree是一种处理树上不带修改,询问子树有关的信息的问题的思想,可以被称为静态链分治(这个称呼比较符合这个算法的特点)

算法实现

对于一个节点:
先递归处理轻儿子。
然后递归处理重儿子。
计算当前节点的答案,这里需要遍历所有轻儿子(本来需要遍历整个子树,因为重儿子的影响我们没有清除,所以不用遍历重儿子的子树)
如果该节点本身是轻儿子,那么就需要清除该节点内所有子树的影响。

大概模板:

void dfs1(int u,int fa){
	sz[u]=1;
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa)continue;
		dfs1(v,u);sz[u]+=sz[v];
		if(sz[son[u]]<sz[v])son[u]=v;
	}
}
void calc(int u,int fa,int val){
	if(val>0){
		modify(alfa[u],-1);//先减去原本的贡献
		change(alfa[u],1);
		modify(alfa[u],1);//加上现在的
	}
	else{
		modify(alfa[u],-1);
		change(alfa[u],-1);
		modify(alfa[u],1);
	}
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa||vis[v])continue;
		calc(v,u,val)
	}
}
void dfs2(int u,int fa,int flag){
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa||v==son[u])continue;
		dfs2(v,u,0);
	}
	if(son[u]){
		dfs2(son[u],u,1);vis[son[u]]=1;
	}
	calc(u,fa,1);vis[son[u]]=0;
	ans[u]=query();//此处根据题目要求
	if(!flag)calc(u,fa,0);
}

例题

树上数颜色
这个是最简单的应用了,alfa记录的是每种颜色的出现次数

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
inline int read(){
	char c=getchar();int t=0,f=1;
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
	return t*f;
}
int n,b,h[maxn],cnt,c[maxn];
struct edge{
	int v,p;
}e[maxn<<1];
inline void add(int a,int b){
	e[++cnt].p=h[a];
	e[cnt].v=b;
	h[a]=cnt;
	e[++cnt].p=h[b];
	e[cnt].v=a;
	h[b]=cnt;
}
int sz[maxn],son[maxn];
void dfs1(int u,int fa){
	sz[u]=1;
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa)continue;
		dfs1(v,u);sz[u]+=sz[v];
		if(sz[son[u]]<sz[v])son[u]=v;
	}
}
int m,alfa[maxn],tot,ans[maxn],vis[maxn];
void calc(int u,int fa,int val){
	if(val>0){
		if(!alfa[c[u]])++tot;
		alfa[c[u]]++;
	}
	else{
		if(alfa[c[u]]==1)tot--;
		alfa[c[u]]--;
	}
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa||vis[v])continue;
		calc(v,u,val);
	}
}
void dfs2(int u,int fa,int flag){
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa||v==son[u])continue;
		dfs2(v,u,0);
	}
	if(son[u]){
		dfs2(son[u],u,1);vis[son[u]]=1;
	}
	calc(u,fa,1);vis[son[u]]=false;
	ans[u]=tot;
	if(!flag)calc(u,fa,-1);
}
int main(){
	n=read();
	for(int i=1;i<n;i++){
		int a=read(),b=read();
		add(a,b);
	}
	dfs1(1,0);
	for(int i=1;i<=n;i++)c[i]=read();
	dfs2(1,0,1);
	m=read();
	while(m--){
		int x=read();
		printf("%d\n",ans[x]);
	}
	return 0;
}

CF600E
和上题差距不大,除了记录alfa意外,还记录了beta,表示出现次数为x的颜色的编号之和

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
inline int read(){
	char c=getchar();int t=0,f=1;
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
	return t*f;
}
int n,c[maxn];
struct edge{
	int v,p;
}e[maxn<<1];
int h[maxn],cnt;
inline void add(int a,int b){
	e[++cnt].p=h[a];
	e[cnt].v=b;
	h[a]=cnt;
	e[++cnt].p=h[b];
	e[cnt].v=a;
	h[b]=cnt;
}
int sz[maxn],alfa[maxn],tot,son[maxn];
void dfs1(int u,int fa){
	sz[u]=1;
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa)continue;
		dfs1(v,u);
		sz[u]+=sz[v];if(sz[son[u]]<sz[v])son[u]=v;
	}
}
long long sum,beta[maxn<<2],ans[maxn];
#define lc rt<<1
#define rc rt<<1|1
void modify(int rt,int l,int r,int x,int val){
	beta[rt]+=val;
	if(l==r){return ;}
	int mid=l+r>>1;
	if(x<=mid)modify(lc,l,mid,x,val);
	else modify(rc,mid+1,r,x,val);
}
long long query(int rt,int l,int r){
	if(l==r){return beta[rt];}
	int mid=l+r>>1;
	if(beta[rc])return query(rc,mid+1,r);
	else return query(lc,l,mid);
}
int vis[maxn];
void build(int rt,int l,int r){
	beta[rt]=sum;
	if(l==r){return ;}
	int mid=l+r>>1;
	build(lc,l,mid);
}
void calc(int u,int fa,int val){
	if(val>0){
		modify(1,0,n,alfa[c[u]],-c[u]);
		alfa[c[u]]++;
		modify(1,0,n,alfa[c[u]],c[u]);
	}
	else{
		modify(1,0,n,alfa[c[u]],-c[u]);
		alfa[c[u]]--;
		modify(1,0,n,alfa[c[u]],c[u]);
	}
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa||vis[v])continue;
		calc(v,u,val);
	}
}
void dfs2(int u,int fa,int flag){
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa||v==son[u])continue;
		dfs2(v,u,0);
	}
	if(son[u]){
		dfs2(son[u],u,1);vis[son[u]]=1;
	}
	calc(u,fa,1);vis[son[u]]=0;
	ans[u]=query(1,0,n);
	if(!flag)calc(u,fa,-1);
}
int fd[maxn];
int main(){
	//freopen("CF600E.in","r",stdin);
	//freopen("CF600E.out","w",stdout);
	n=read();
	for(int i=1;i<=n;i++){c[i]=read();if(!fd[c[i]]){sum+=c[i];fd[c[i]]=1;}}
	for(int i=1;i<n;i++){
		int a=read(),b=read();
		add(a,b);
	}
	build(1,0,n);
	dfs1(1,0);
	dfs2(1,0,1);
	for(int i=1;i<=n;i++)printf("%lld ",ans[i]);
	return 0;
}

发布了62 篇原创文章 · 获赞 1 · 访问量 1002

猜你喜欢

转载自blog.csdn.net/wmhtxdy/article/details/103766177