「ZJOI2008」 树的统计 - 树链剖分

题目大意

给定一棵无根树,每个点有一个权值,有三种操作:修改权值,询问x到y的路径上节点的最大权值和权值和。输出询问的结果。

分析

树剖模板题。由于是无根树,所以任意取一个节点作为根都是没有问题的,然后记size(u)为以u为根的子树的节点数,若vu的儿子中size值最大的一个节点,则称边(u,v)为重边,vu的重儿子。u到其他儿子的边称为轻边。为了处理u,v之间的路径,可以分别处理到它们的LCA的路径。对于由重边或者一个节点组成的重路径,可以用线段树等数据结构维护;对于轻边,可以直接跳过,因为轻边的两节点一定在某条重路径上。轻重边的剖分可以用两次Dfs实现。在剖分过程中需要求出以下的值:

  • prt[x]:x的父亲
  • dep[x]:x的深度
  • size[x]:x的子树节点数
  • son[x]:x的重儿子
  • top[x]:x所在重路径的顶部节点
  • seg[x]:x在线段树中的下标
  • rev[x]:在线段树中第x个位置所对应树中的节点编号

求出来之后,用线段树维护summax即可。具体实现看代码。

代码

#include <cstdio>
#include <iostream>
using namespace std;
#define N 30005
struct Edge {
	long long to,next;
}e[N*2];
int h[N],cnt;
int prt[N],dep[N],size[N],son[N];
int top[N],seg[N],rev[N*4];
int sum[N*4],maxa[N*4],val[N];
int n,m,ans_max,ans_sum;
void add(int x,int y) {
	e[++cnt]=(Edge){y,h[x]};
	h[x]=cnt;
}
void Dfs1(int x,int fa) {//第1遍Dfs,求prt,size,dep和son
	prt[x]=fa;
	size[x]=1;
	dep[x]=dep[fa]+1;
	for (int i=h[x];i;i=e[i].next) {
		int y=e[i].to;
		if (y==fa) continue;
		Dfs1(y,x);
		size[x]+=size[y];
		if (size[y]>size[son[x]]) son[x]=y;
	}
}
void Dfs2(int x,int prt) {//第2遍Dfs,求seg,top,rev
	if (son[x]) {
		seg[son[x]]=++seg[0];
		top[son[x]]=top[x];
		rev[seg[0]]=son[x];
		Dfs2(son[x],x);
	}
	for (int i=h[x];i;i=e[i].next) {
		int y=e[i].to;
		if (top[y]) continue;
		seg[y]=++seg[0];
		rev[seg[0]]=y;
		top[y]=y;
		Dfs2(y,x);
	}
}
void build(int p,int l,int r) {//建树
	if (l==r) {
		maxa[p]=sum[p]=val[rev[l]];
		return;
	}
	int mid=(l+r)>>1;
	build(p<<1,l,mid);
	build(p<<1|1,mid+1,r);
	maxa[p]=max(maxa[p<<1],maxa[p<<1|1]);
	sum[p]=sum[p<<1]+sum[p<<1|1];
}
void query(int p,int l,int r,int x,int y) {//区间查询
	if (y<l||x>r) return;
	if (x<=l&&r<=y) {
		ans_max=max(ans_max,maxa[p]);
		ans_sum+=sum[p];
		return;
	}
	int mid=(l+r)>>1;
	query(p<<1,l,mid,x,y);
	query(p<<1|1,mid+1,r,x,y);
}
void update(int p,int l,int r,int x,int y) {//单点修改
	if (x<l||x>r) return;
	if (l==r) {
		sum[p]=y;
		maxa[p]=y;
		return;
	}
	int mid=(l+r)>>1;
	update(p<<1,l,mid,x,y);
	update(p<<1|1,mid+1,r,x,y);
	sum[p]=sum[p<<1]+sum[p<<1|1];
	maxa[p]=max(maxa[p<<1],maxa[p<<1|1]);
}
void ask(int x,int y) {//询问x到y路径上的sum或max
	int fx=top[x],fy=top[y];
	while (fx!=fy) {
		if (dep[fx]<dep[fy]) {//类似LCA向上跳的方法,每次选择深度最大的往上跳
			swap(x,y);
			swap(fx,fy);
		}
		query(1,1,seg[0],seg[fx],seg[x]);//对于重路径,直接查询
		x=prt[fx];
		fx=top[x];
	}
	if (dep[x]>dep[y]) swap(x,y);
	query(1,1,seg[0],seg[x],seg[y]);
}
int read() {
	char ch=getchar();
	int s=0,f=1;
	while (ch<'0'||ch>'9') f=(ch=='-'?-1:1),ch=getchar();
	while (ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
	return f*s;
}
int main() {
	n=read();
	for (int i=1;i<n;i++) {
		int a,b;
		a=read(),b=read();
		add(a,b);
		add(b,a);
	}
	for (int i=1;i<=n;i++)
		val[i]=read();
	Dfs1(1,0);
	seg[0]=seg[1]=rev[1]=top[1]=1;
	Dfs2(1,0);
	build(1,1,seg[0]);
	m=read();
	while (m--) {
		char op[10];
		int x,y;
		scanf("%s",op);x=read(),y=read();
		if (op[0]=='C') {
			update(1,1,seg[0],seg[x],y);
		} else {
			ans_max=-(1<<30);
			ans_sum=0;
			ask(x,y);
			if (op[1]=='M') {
				printf("%d\n",ans_max);
			} else {
				printf("%d\n",ans_sum);
			}
		}
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/sin_Yang/article/details/82384399