BZOJ 4034 [HAOI2015]树上操作 线段树+树剖或dfs

题意

直接照搬原题面

有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个

操作,分为三种:

操作 1 :把某个节点 x 的点权增加 a 。

操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。

操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

分析

先树剖一下,按重新编号的点建线段树

  • 操作1:直接单点修改

  • 操作2:一个子树里的点的编号是连在一起的,直接区间修改

  • 操作3:该点的\(top\)不为1时,即该点跟根结点不在一条链上,加上这条链的贡献(线段树的区间求和),

    再跳到\(top\)的父节点所在链,直到\(top\)为1再加上\(top\)为1这条链的贡献,就能求出1到x的答案了

其实还有另一种不用树剖的做法,用线段树维护前缀和,\(a[x]\)为从\(1\)\(x\)的点权和,操作1就等于区间修改\(x\)的子树中所有节点,

操作2就等于对\(x\)的子树中每个节点进行一次操作1,这肯定不行,考虑单个节点的贡献,每个节点总共增加的值为它在\(x\)的子树中的深度\(p\)

乘上增加量\(k\),区间贡献和即为区间深度之和乘\(k\).

线段树要多记录区间结点的深度和\(w[p]\),区间修改的式子为\(val[p]+=w[p]*k-(r-l+1)*dep*k\)\(dep\)\(x\)的父节点的深度

加个lazy标记记录\(dep*k​\)就行了

Code 1

#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define lson l,mid,p<<1
#define rson mid+1,r,p<<1|1
using namespace std;
typedef long long ll;
const int inf=1e9;
const int maxn=3e5+10;
int n,q;
ll a[maxn];
vector<int>g[maxn];
int top[maxn],in[maxn],out[maxn],sz[maxn],f[maxn],son[maxn],id[maxn],tot;
ll val[maxn<<2],tag[maxn<<2];
void pp(int p){val[p]=val[p<<1]+val[p<<1|1];}
void pd(int l,int r,int p,ll k){val[p]+=(r-l+1)*k,tag[p]+=k;}
void bd(int l,int r,int p){
    if(l==r) return val[p]=a[id[l]],void();
    int mid=l+r>>1;
    bd(lson);bd(rson);pp(p);
}
void up(int dl,int dr,int l,int r,int p,ll k){
    if(l>=dl&&r<=dr){
        val[p]+=(r-l+1)*k;tag[p]+=k;
        return;
    }int mid=l+r>>1;
    pd(lson,tag[p]);pd(rson,tag[p]);tag[p]=0;
    if(dl<=mid) up(dl,dr,lson,k);
    if(dr>mid) up(dl,dr,rson,k);
    pp(p);
}
ll qy(int dl,int dr,int l,int r,int p){
    if(l>=dl&&r<=dr) return val[p];
    int mid=l+r>>1;ll ret=0;
    pd(lson,tag[p]);pd(rson,tag[p]);tag[p]=0;
    if(dl<=mid) ret+=qy(dl,dr,lson);
    if(dr>mid) ret+=qy(dl,dr,rson);
    return ret;
}
void dfs1(int u){
    sz[u]=1;
    for(int i=0;i<g[u].size();i++){
        int x=g[u][i];
        if(x==f[u]) continue;
        f[x]=u;dfs1(x);
        sz[u]+=sz[x];
        if(sz[son[u]]<sz[x]) son[u]=x;
    }
}
void dfs2(int u,int t){
    top[u]=t;in[u]=++tot;id[tot]=u;
    if(son[u]) dfs2(son[u],t);
    for(int i=0;i<g[u].size();i++){
        int x=g[u][i];
        if(x==f[u]||x==son[u]) continue;
        dfs2(x,x);
    }
    out[u]=tot;
}
ll cal(int x){
    ll res=0;
    while(top[x]!=1){
        res+=qy(in[top[x]],in[x],1,n,1);
        x=f[top[x]];
    }
    res+=qy(1,in[x],1,n,1);return res;
}
int main(){
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++){
        scanf("%lld",&a[i]);
    }
    for(int i=1,a,b;i<n;i++){
        scanf("%d%d",&a,&b);
        g[a].pb(b);g[b].pb(a);
    }
    dfs1(1);dfs2(1,1);bd(1,n,1);
    while(q--){
        int op,x;ll a;
        scanf("%d%d",&op,&x);
        if(op==1){
            scanf("%lld",&a);
            up(in[x],in[x],1,n,1,a);
        }else if(op==2){
            scanf("%lld",&a);
            up(in[x],out[x],1,n,1,a);
        }else{
            printf("%lld\n",cal(x));
        }
    }
    return 0;
}

Code 2

#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define lson l,mid,p<<1
#define rson mid+1,r,p<<1|1
using namespace std;
typedef long long ll;
const int inf=1e9;
const int maxn=3e5+10;
int n,q;
int d[maxn];
ll a[maxn],dep[maxn];
vector<int>g[maxn];
int f[maxn],in[maxn],out[maxn],tot;
ll val[maxn<<2],tag[maxn<<2],w[maxn<<2],tw[maxn<<2],qw[maxn<<2];
void pushup(int p){
    val[p]=val[p<<1]+val[p<<1|1];
}
void tag1(int l,int r,int p,ll k,ll tk,ll qk){
    val[p]+=w[p]*k-(r-l+1)*tk+(r-l+1)*qk;tag[p]+=k;
    tw[p]+=tk;qw[p]+=qk;
}
void bd(int l,int r,int p){
    if(l==r){
        val[p]=a[d[l]];
        w[p]=dep[d[l]];
        return;
    }
    int mid=l+r>>1;
    bd(lson);bd(rson);
    w[p]=w[p<<1]+w[p<<1|1];
    pushup(p);
}
void up(int dl,int dr,int l,int r,int p,ll k,ll dep){
    if(l>=dl&&r<=dr){
        val[p]+=(w[p]-(r-l+1)*dep)*k;tag[p]+=k;
        tw[p]+=dep*k;
        return;
    }int mid=l+r>>1;
    tag1(lson,tag[p],tw[p],qw[p]);tag1(rson,tag[p],tw[p],qw[p]);tag[p]=0;tw[p]=0;qw[p]=0;
    if(dl<=mid) up(dl,dr,lson,k,dep);
    if(dr>mid) up(dl,dr,rson,k,dep);
    pushup(p);
}
void upd(int dl,int dr,int l,int r,int p,ll k){
    if(l>=dl&&r<=dr){
        val[p]+=(r-l+1)*k;qw[p]+=k;
        return;
    }int mid=l+r>>1;
    tag1(lson,tag[p],tw[p],qw[p]);tag1(rson,tag[p],tw[p],qw[p]);tag[p]=0;tw[p]=0;qw[p]=0;
    if(dl<=mid) upd(dl,dr,lson,k);
    if(dr>mid) upd(dl,dr,rson,k);
    pushup(p);
}
ll qy(int dl,int dr,int l,int r,int p){
    ll ret=0;
    if(l>=dl&&r<=dr) return val[p];int mid=l+r>>1;
    tag1(lson,tag[p],tw[p],qw[p]);tag1(rson,tag[p],tw[p],qw[p]);tag[p]=0;tw[p]=0;qw[p]=0;
    if(dl<=mid) ret+=qy(dl,dr,lson);
    if(dr>mid) ret+=qy(dl,dr,rson);
    return ret;
}
void dfs(int u){
    in[u]=++tot;d[tot]=u;dep[u]=dep[f[u]]+1;
    for(int i=0;i<g[u].size();i++){
        int x=g[u][i];
        if(x==f[u]) continue;
        f[x]=u;a[x]+=a[u];
        dfs(x);
    }
    out[u]=tot;
}
int main(){
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++){
        scanf("%lld",&a[i]);
    }
    for(int i=1,a,b;i<n;i++){
        scanf("%d%d",&a,&b);
        g[a].pb(b);g[b].pb(a);
    }
    dfs(1);
    bd(1,n,1);
    while(q--){
        int op,x;
        ll a;
        scanf("%d%d",&op,&x);
        if(op!=3) scanf("%lld",&a);
        if(op==1) upd(in[x],out[x],1,n,1,a);
        else if(op==2) up(in[x],out[x],1,n,1,a,dep[f[x]]);
        else printf("%lld\n",qy(in[x],in[x],1,n,1));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/xyq0220/p/10911425.html