dsu on tree是一种处理树上不带修改,询问子树有关的信息的问题的思想,可以被称为静态链分治(这个称呼比较符合这个算法的特点)
算法实现
对于一个节点:
先递归处理轻儿子。
然后递归处理重儿子。
计算当前节点的答案,这里需要遍历所有轻儿子(本来需要遍历整个子树,因为重儿子的影响我们没有清除,所以不用遍历重儿子的子树)
如果该节点本身是轻儿子,那么就需要清除该节点内所有子树的影响。
大概模板:
void dfs1(int u,int fa){
sz[u]=1;
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa)continue;
dfs1(v,u);sz[u]+=sz[v];
if(sz[son[u]]<sz[v])son[u]=v;
}
}
void calc(int u,int fa,int val){
if(val>0){
modify(alfa[u],-1);//先减去原本的贡献
change(alfa[u],1);
modify(alfa[u],1);//加上现在的
}
else{
modify(alfa[u],-1);
change(alfa[u],-1);
modify(alfa[u],1);
}
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||vis[v])continue;
calc(v,u,val)
}
}
void dfs2(int u,int fa,int flag){
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||v==son[u])continue;
dfs2(v,u,0);
}
if(son[u]){
dfs2(son[u],u,1);vis[son[u]]=1;
}
calc(u,fa,1);vis[son[u]]=0;
ans[u]=query();//此处根据题目要求
if(!flag)calc(u,fa,0);
}
例题
树上数颜色
这个是最简单的应用了,alfa记录的是每种颜色的出现次数
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
inline int read(){
char c=getchar();int t=0,f=1;
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n,b,h[maxn],cnt,c[maxn];
struct edge{
int v,p;
}e[maxn<<1];
inline void add(int a,int b){
e[++cnt].p=h[a];
e[cnt].v=b;
h[a]=cnt;
e[++cnt].p=h[b];
e[cnt].v=a;
h[b]=cnt;
}
int sz[maxn],son[maxn];
void dfs1(int u,int fa){
sz[u]=1;
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa)continue;
dfs1(v,u);sz[u]+=sz[v];
if(sz[son[u]]<sz[v])son[u]=v;
}
}
int m,alfa[maxn],tot,ans[maxn],vis[maxn];
void calc(int u,int fa,int val){
if(val>0){
if(!alfa[c[u]])++tot;
alfa[c[u]]++;
}
else{
if(alfa[c[u]]==1)tot--;
alfa[c[u]]--;
}
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||vis[v])continue;
calc(v,u,val);
}
}
void dfs2(int u,int fa,int flag){
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||v==son[u])continue;
dfs2(v,u,0);
}
if(son[u]){
dfs2(son[u],u,1);vis[son[u]]=1;
}
calc(u,fa,1);vis[son[u]]=false;
ans[u]=tot;
if(!flag)calc(u,fa,-1);
}
int main(){
n=read();
for(int i=1;i<n;i++){
int a=read(),b=read();
add(a,b);
}
dfs1(1,0);
for(int i=1;i<=n;i++)c[i]=read();
dfs2(1,0,1);
m=read();
while(m--){
int x=read();
printf("%d\n",ans[x]);
}
return 0;
}
CF600E
和上题差距不大,除了记录alfa意外,还记录了beta,表示出现次数为x的颜色的编号之和
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
inline int read(){
char c=getchar();int t=0,f=1;
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n,c[maxn];
struct edge{
int v,p;
}e[maxn<<1];
int h[maxn],cnt;
inline void add(int a,int b){
e[++cnt].p=h[a];
e[cnt].v=b;
h[a]=cnt;
e[++cnt].p=h[b];
e[cnt].v=a;
h[b]=cnt;
}
int sz[maxn],alfa[maxn],tot,son[maxn];
void dfs1(int u,int fa){
sz[u]=1;
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa)continue;
dfs1(v,u);
sz[u]+=sz[v];if(sz[son[u]]<sz[v])son[u]=v;
}
}
long long sum,beta[maxn<<2],ans[maxn];
#define lc rt<<1
#define rc rt<<1|1
void modify(int rt,int l,int r,int x,int val){
beta[rt]+=val;
if(l==r){return ;}
int mid=l+r>>1;
if(x<=mid)modify(lc,l,mid,x,val);
else modify(rc,mid+1,r,x,val);
}
long long query(int rt,int l,int r){
if(l==r){return beta[rt];}
int mid=l+r>>1;
if(beta[rc])return query(rc,mid+1,r);
else return query(lc,l,mid);
}
int vis[maxn];
void build(int rt,int l,int r){
beta[rt]=sum;
if(l==r){return ;}
int mid=l+r>>1;
build(lc,l,mid);
}
void calc(int u,int fa,int val){
if(val>0){
modify(1,0,n,alfa[c[u]],-c[u]);
alfa[c[u]]++;
modify(1,0,n,alfa[c[u]],c[u]);
}
else{
modify(1,0,n,alfa[c[u]],-c[u]);
alfa[c[u]]--;
modify(1,0,n,alfa[c[u]],c[u]);
}
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||vis[v])continue;
calc(v,u,val);
}
}
void dfs2(int u,int fa,int flag){
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||v==son[u])continue;
dfs2(v,u,0);
}
if(son[u]){
dfs2(son[u],u,1);vis[son[u]]=1;
}
calc(u,fa,1);vis[son[u]]=0;
ans[u]=query(1,0,n);
if(!flag)calc(u,fa,-1);
}
int fd[maxn];
int main(){
//freopen("CF600E.in","r",stdin);
//freopen("CF600E.out","w",stdout);
n=read();
for(int i=1;i<=n;i++){c[i]=read();if(!fd[c[i]]){sum+=c[i];fd[c[i]]=1;}}
for(int i=1;i<n;i++){
int a=read(),b=read();
add(a,b);
}
build(1,0,n);
dfs1(1,0);
dfs2(1,0,1);
for(int i=1;i<=n;i++)printf("%lld ",ans[i]);
return 0;
}