P2590 [ZJOI2008]树的统计

题目描述

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

输入输出格式

输入格式:

 

输入文件的第一行为一个整数n,表示节点的个数。

接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

接下来一行n个整数,第i个整数wi表示节点i的权值。

接下来1行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。

 

输出格式:

 

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

 

输入输出样例

输入样例#1: 
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
输出样例#1: 
4
1
2
2
10
6
5
6
5
16

说明

对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Solution:

  树剖的模板题。。。

  总结一波错误:若比较函数$Max$是宏定义,且比较的两个变量中含有函数,那么不要用宏定义的$Max$,因为这样函数会运行两次,白白的浪费时间。

代码:

  1 #include<bits/stdc++.h>
  2 #define il inline
  3 #define lson l,m,rt<<1
  4 #define rson m+1,r,rt<<1|1
  5 #define For(i,a,b) for(int (i)=(a);(i)<=(b);(i)++)
  6 #define Swap(a,b) ((a)^=(b),(b)^=(a),(a)^=(b))
  7 using namespace std;
  8 const int N=100005,inf=233333333;
  9 int n,q,cnt,h[N],a[N];
 10 struct node{
 11     int to,net,w;
 12 }e[N];
 13 int size[N],wson[N],fa[N],dep[N],top[N],pos[N],pre[N],tot;
 14 
 15 il int gi(){
 16     int a=0;char x=getchar();bool f=0;
 17     while((x<'0'||x>'9')&&x!='-')x=getchar();
 18     if(x=='-')x=getchar(),f=1;
 19     while(x>='0'&&x<='9')a=(a<<3)+(a<<1)+x-48,x=getchar();
 20     return f?-a:a;
 21 }
 22 
 23 il void add(int u,int v){
 24     e[++cnt].to=v,e[cnt].net=h[u],h[u]=cnt;
 25     e[++cnt].to=u,e[cnt].net=h[v],h[v]=cnt;
 26 }
 27 
 28 il void dfs1(int u,int f){
 29     size[u]=1;
 30     for(int i=h[u];i;i=e[i].net){
 31         int v=e[i].to;
 32         if(v==f)continue;
 33         dep[v]=dep[u]+1;fa[v]=u;
 34         dfs1(v,u);
 35         size[u]+=size[v];
 36         if(size[v]>size[wson[u]])wson[u]=v;
 37     }
 38 }
 39 
 40 il void dfs2(int u,int op){
 41     pos[u]=++tot;pre[tot]=u;top[u]=op;
 42     if(wson[u])dfs2(wson[u],op);
 43     for(int i=h[u];i;i=e[i].net){
 44         int v=e[i].to;
 45         if(v==fa[u]||v==wson[u])continue;
 46         dfs2(v,v);
 47     }
 48 }
 49 
 50 int sum[N<<2],maxn[N<<2];
 51 
 52 il void pushup(int rt){
 53     sum[rt]=sum[rt<<1]+sum[rt<<1|1];
 54     maxn[rt]=max(maxn[rt<<1],maxn[rt<<1|1]);
 55 }
 56 
 57 il void build(int l,int r,int rt){
 58     if(l==r){sum[rt]=maxn[rt]=a[pre[l]];return;}
 59     int m=l+r>>1;
 60     build(lson),build(rson);
 61     pushup(rt);
 62 }
 63 
 64 il void update(int k,int v,int l,int r,int rt){
 65     if(l==r){sum[rt]=maxn[rt]=v;return;}
 66     int m=l+r>>1;
 67     if(k<=m)update(k,v,lson);
 68     else update(k,v,rson);
 69     pushup(rt);
 70 }
 71 
 72 il int query1(int L,int R,int l,int r,int rt){
 73     if(L<=l&&R>=r)return sum[rt];
 74     int m=l+r>>1,ret=0;
 75     if(L<=m)ret+=query1(L,R,lson);
 76     if(R>m)ret+=query1(L,R,rson);
 77     return ret;
 78 }
 79 
 80 il int query2(int L,int R,int l,int r,int rt){
 81     if(L<=l&&R>=r)return maxn[rt];
 82     int m=l+r>>1,tmp=-inf;
 83     if(L<=m)tmp=max(tmp,query2(L,R,lson));
 84     if(R>m)tmp=max(tmp,query2(L,R,rson));
 85     return tmp;
 86 }
 87 
 88 il int getsum(int u,int v){
 89     int ans=0;
 90     while(top[u]!=top[v]){
 91         if(dep[top[u]]<dep[top[v]])Swap(u,v);
 92         ans+=query1(pos[top[u]],pos[u],1,n,1);
 93         u=fa[top[u]];
 94     }
 95     if(dep[u]<dep[v])Swap(u,v);
 96     ans+=query1(pos[v],pos[u],1,n,1);
 97     return ans;
 98 }
 99 
100 il int getmax(int u,int v){
101     int tmp=-inf;
102     while(top[u]!=top[v]){
103         if(dep[top[u]]<dep[top[v]])Swap(u,v);
104         tmp=max(tmp,query2(pos[top[u]],pos[u],1,n,1));
105         u=fa[top[u]];
106     }
107     if(dep[u]<dep[v])Swap(u,v);
108     tmp=max(tmp,query2(pos[v],pos[u],1,n,1));
109     return tmp;
110 }
111 
112 int main(){
113     n=gi();
114     int u,v;char s[10];
115     For(i,1,n-1)u=gi(),v=gi(),add(u,v);
116     For(i,1,n)a[i]=gi();
117     dep[1]=1,fa[1]=1;
118     dfs1(1,-1);dfs2(1,1);
119     build(1,n,1);
120     q=gi();
121     while(q--){
122         scanf("%s",s),u=gi(),v=gi();
123         if(s[1]=='H')update(pos[u],v,1,n,1);
124         if(s[1]=='M')printf("%d\n",getmax(u,v));
125         if(s[1]=='S')printf("%d\n",getsum(u,v));
126     }
127     return 0;
128 }

猜你喜欢

转载自www.cnblogs.com/five20/p/9163285.html
今日推荐