版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/chimchim04/article/details/89670959
题目描述
这是一道模板题。
给一棵有根树,这棵树由编号为1...N的N个结点组成。根结点的编号为R。每个结点都有一个权值,结点i的权值为Vi 。
接下来有 组操作,操作分为两类:
- 1 a x,表示将结点a 的权值增加 x;
- 2 a,表示求结点a的子树上所有结点的权值之和。
输入
第一行有三个整数N,M和R 。
第二行有 N个整数,第i 个整数表示Vi。
在接下来的N-1 行中,每行两个整数,表示一条边。
在接下来的M 行中,每行一组操作。
输出
对于每组 2 a 操作,输出一个整数,表示「以结点 为根的子树」上所有结点的权值之和。
样例输入 Copy
10 14 9
12 -6 -4 -3 12 8 9 6 6 2
8 2
2 10
8 6
2 7
7 1
6 3
10 9
2 4
10 5
1 4 -1
2 2
1 7 -1
2 10
1 10 5
2 1
1 7 -5
2 5
1 1 8
2 7
1 8 8
2 2
1 5 5
2 6
样例输出 Copy
21
34
12
12
23
31
4
提示
N,M<=1e6,R<=N,-1e6<=Vi ,x<=1e6
来源/分类
思路:按dfs遍历序重新给节点标号,每个节点的子树节点就能用区间表示 ,接下来就可以用线段树,树状数组,zkw线段树,解决单点更新区间求和的问题
线段树代码:
//Time:1372 ms
//Memory:71392 kb
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
#define ll long long
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
const int N=1000005;
int n,tot,first[N],a[N],to,b[N],d[N],di;
ll tr[N<<2],c[N];
struct node
{
int v,nex;
}e[N<<1];
void add(int u,int v)
{
e[tot].v=v;
e[tot].nex=first[u];
first[u]=tot++;
}
int dfs(int rt)
{
d[to]=rt;
a[rt]=to++;
int ans=a[rt];
for(int i=first[rt];i!=-1;i=e[i].nex)
{
int v=e[i].v;
if(a[v]) continue;
ans=max(ans,dfs(v));
}
b[rt]=ans;
// printf("%d %d %d\n",rt,a[rt],b[rt]);
return ans;
}
void pushup(int rt)
{
tr[rt]=tr[rt<<1]+tr[rt<<1|1];
}
void build(int l,int r,int rt)
{
if(l==r)
{
tr[rt]=c[d[di++]];
//printf("%d %d %lld\n",rt,d[di-1],c[d[di-1]]);
return ;
}
int m=(l+r)>>1;
build(lson);
build(rson);
pushup(rt);
}
void updata(int p,int d,int l,int r,int rt)
{
if(l==r)
{
tr[rt]+=d;
return;
}
int m=(l+r)/2;
if(p<=m) updata(p,d,lson);
else updata(p,d,rson);
pushup(rt);
}
ll query(int L,int R,int l,int r,int rt)
{
if(L<=l&&r<=R) return tr[rt];
int m=(l+r)/2;
ll ret=0;
if(L<=m) ret+=query(L,R,lson);
if(R>m) ret+=query(L,R,rson);
return ret;
}
int main()
{
int m,RR;
scanf("%d%d%d",&n,&m,&RR);
for(int i=1;i<=n;i++) scanf("%lld",&c[i]);
int x,y,z;
memset(first,-1,sizeof(first));
tot=0,to=1;
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(RR);
di=1;
build(1,n,1);
for(int i=0;i<m;i++)
{
scanf("%d",&z);
if(z==1)
{
scanf("%d%d",&x,&y);
updata(a[x],y,1,n,1);
}
else
{
scanf("%d",&x);
printf("%lld\n",query(a[x],b[x],1,n,1));
}
}
return 0;
}
树状数组代码:
//Time:1268 ms
//Memory:63580 kb
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
#define ll long long
const int N=1000005;
int n,tot,first[N],a[N],to,b[N],c[N];
ll tr[N<<2];
struct node
{
int v,nex;
}e[N<<1];
void adde(int u,int v)
{
e[tot].v=v;
e[tot].nex=first[u];
first[u]=tot++;
}
int lowbit(int x)
{
return x&-x;
}
void add(int k,int x)
{
while(k<=n)
{
tr[k]=tr[k]+(ll)x;
k+=lowbit(k);
}
}
ll sum(int x)
{
ll sum=0;
while(x>0)
{
sum+=tr[x];
x-=lowbit(x);
}
return sum;
}
int dfs(int rt)
{
add(to,c[rt]);
a[rt]=to++;
int ans=a[rt];
for(int i=first[rt];i!=-1;i=e[i].nex)
{
int v=e[i].v;
if(a[v]) continue;
ans=max(ans,dfs(v));
}
b[rt]=ans;
return ans;
}
int main()
{
int m,RR;
scanf("%d%d%d",&n,&m,&RR);
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
int x,y,z;
memset(first,-1,sizeof(first));
tot=0,to=1;
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&x,&y);
adde(x,y);
adde(y,x);
}
dfs(RR);
for(int i=0;i<m;i++)
{
scanf("%d",&z);
if(z==1)
{
scanf("%d%d",&x,&y);
add(a[x],y);
}
else
{
scanf("%d",&x);
printf("%lld\n",sum(b[x])-sum(a[x]-1));
}
}
return 0;
}
zkw线段树:
//Time:1081 ms
//Memory:71392 kb
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
#define ll long long
const int N=1000005;
int n,tot,first[N],a[N],to,b[N],d[N],di,M;
ll tr[N<<2],c[N];
struct node
{
int v,nex;
} e[N<<1];
void add(int u,int v)
{
e[tot].v=v;
e[tot].nex=first[u];
first[u]=tot++;
}
int dfs(int rt)
{
d[to]=rt;
a[rt]=to++;
int ans=a[rt];
for(int i=first[rt]; i!=-1; i=e[i].nex)
{
int v=e[i].v;
if(a[v]) continue;
ans=max(ans,dfs(v));
}
b[rt]=ans;
return ans;
}
void pushup(int rt)
{
tr[rt]=tr[rt<<1]+tr[rt<<1|1];
}
void build()
{
for(M=1; M<n; M<<=1);
for(int i=M+1; i<=M+n; i++) tr[i]=c[d[di++]];
for(int i=M; i>=1; i--) pushup(i);
}
void updata(int x,int y)
{
for(tr[x+=M]+=y,x>>=1; x; x>>=1)
pushup(x);
}
ll query(int l,int r)
{
ll ans=0;
for(l=l+M-1,r=r+M+1; l^r^1; l>>=1,r>>=1)
{
if(l&1^1) ans+=tr[l^1]; //l+1
if(r&1) ans+=tr[r^1]; //r-1
}
return ans;
}
int main()
{
int m,RR;
scanf("%d%d%d",&n,&m,&RR);
for(int i=1; i<=n; i++) scanf("%lld",&c[i]);
int x,y,z;
memset(first,-1,sizeof(first));
tot=0,to=1;
for(int i=0; i<n-1; i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(RR);
di=1;
build();
for(int i=0; i<m; i++)
{
scanf("%d",&z);
if(z==1)
{
scanf("%d%d",&x,&y);
updata(a[x],y);
}
else
{
scanf("%d",&x);
printf("%lld\n",query(a[x],b[x]));
}
}
return 0;
}