【GDOI2016】疯狂动物城(树链剖分+可持久化线段树)

版权声明:转载请声明 https://blog.csdn.net/ezoiHQM/article/details/81675310

码农题……
调了我三个晚上……
看来我的代码能力还是太弱了……
首先我们不难发现在u到v这条链的答案为 i = 1 n ( n i ) ( n i + 1 ) a i 2
然后把它拆开可以得到答案为 i = 1 n ( n + 1 ) n a i ( 2 n + 1 ) i a i + i 2 a i
对链进行操作自然就是树剖啦
对于第一部分就是普通的树剖
第二部分,我们对线段树的每一个节点维护 s u m r = i = 1 n i a i s u m l = i = 1 n ( n i + 1 ) a i
信息合并:

s u m r [ o ] = s u m r [ l c h [ o ] ] + s u m r [ r c h [ o ] ] + s u m [ r c h [ o ] ] s i z [ l c h [ o ] ]

s u m l 同理
其中 s u m 表示当前子树所表示的区间的权值和
懒标记就加上 t a g [ o ] s i z [ o ] ( s i z [ o ] + 1 ) 2
第三部分,我们对线段树的每一个节点维护 s s u m r = i = 1 n i 2 a i s s u m l = i = 1 n ( n i + 1 ) 2 a i
信息合并:
s s u m r [ o ] = s s u m r [ l c h [ o ] ] + s s u m [ r c h [ o ] ] + s u m [ r c h [ o ] ] s i z [ l c h [ o ] ] 2 + 2 s u m r [ r c h [ o ] ] s i z [ l c h [ o ] ]

s u m l 同理
懒标记就加上 t a g [ o ] s i z [ o ] ( s i z [ o ] + 1 ) ( 2 s i z [ o ] + 1 ) 6
树剖的时候的信息合并同理。
再加上一个持久化就好了。
为了避免爆空间,加个标记永久化就好了。
然后就是考验代码能力的时候了!!!!!!!!
如果有误在评论区吼一声哦!
代码:

#include<cstdio>
#include<cctype>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const ll mod=20160501,inv2=10080251;
const int N=1e5+10;
int n,m,tot,cnt,tcnt,bcnt,now,lastans,head[N],to[N<<1],nxt[N<<1],f[N],dep[N],tp[N],son[N],siz[N],rnk[N],dfn[N],val[N],rt[N],ch[N*100][2],idx[N*100];
ll sum[N*100][5],tag[N*100];
void add_edge(int u,int v){
    nxt[++tot]=head[u];
    to[tot]=v;
    head[u]=tot;
    return;
}
void maintain(int o,int l,int r){
    int mid=(l+r)>>1;
    sum[o][0]=(sum[ch[o][0]][0]+sum[ch[o][1]][0]+sum[ch[o][1]][2]*(mid-l+1)+1ll*(r-l+2)*(r-l+1)/2%mod*tag[o])%mod;
    sum[o][1]=(sum[ch[o][0]][1]+sum[ch[o][1]][1]+sum[ch[o][0]][2]*(r-mid)+1ll*(r-l+2)*(r-l+1)/2%mod*tag[o])%mod;
    sum[o][2]=(sum[ch[o][0]][2]+sum[ch[o][1]][2]+tag[o]*(r-l+1))%mod;
    sum[o][3]=(sum[ch[o][0]][3]+sum[ch[o][1]][3]+sum[ch[o][1]][2]*(mid-l+1)*(mid-l+1)+sum[ch[o][1]][0]*2*(mid-l+1)+1ll*(r-l+1)*(r-l+2)*(2*r-2*l+3)/6%mod*tag[o])%mod;
    sum[o][4]=(sum[ch[o][0]][4]+sum[ch[o][1]][4]+sum[ch[o][0]][2]*(r-mid)%mod*(r-mid)+sum[ch[o][0]][1]*2*(r-mid)+1ll*(r-l+1)*(r-l+2)*(2*r-2*l+3)/6%mod*tag[o])%mod;
    return;
}
void build(int &o,int l,int r){
    o=++tcnt;
    if(l==r){
        sum[o][0]=sum[o][1]=sum[o][2]=sum[o][3]=sum[o][4]=val[rnk[l]];
        return;
    }
    int mid=(l+r)>>1;
    build(ch[o][0],l,mid);
    build(ch[o][1],mid+1,r);
    maintain(o,l,r);
    return;
}
void dfs1(int u){
    siz[u]=1;
    for(int i=head[u];~i;i=nxt[i]){
        int v=to[i];
        if(v==f[u])
            continue;
        f[v]=u;
        dep[v]=dep[u]+1;
        dfs1(v);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]])
            son[u]=v;
    }
    return;
}
void dfs2(int u,int tpx){
    rnk[dfn[u]=++cnt]=u;
    tp[u]=tpx;
    if(!son[u])
        return;
    dfs2(son[u],tpx);
    for(int i=head[u];~i;i=nxt[i]){
        int v=to[i];
        if(v!=f[u]&&v!=son[u])
            dfs2(v,v);
    }
    return;
}
ll query(int o,int l,int r,int L,int R,int k){
    if(!o)
        return 0;
    if(L==l&&r==R)
        return sum[o][k];
    int mid=(l+r)>>1;
    if(R<=mid)
        return (query(ch[o][0],l,mid,L,R,k)+tag[o]*(k==2?(R-L+1):(k<2?1ll*(R-L+1)*(R-L+2)/2%mod:1ll*(R-L+1)*(R-L+2)*(2*R-2*L+3)/6%mod)))%mod;
    else if(L>mid)
        return (query(ch[o][1],mid+1,r,L,R,k)+tag[o]*(k==2?(R-L+1):(k<2?1ll*(R-L+1)*(R-L+2)/2%mod:1ll*(R-L+1)*(R-L+2)*(2*R-2*L+3)/6%mod)))%mod;
    if(!k)
        return (query(ch[o][0],l,mid,L,mid,k)+query(ch[o][1],mid+1,r,mid+1,R,2)*(mid-L+1)+query(ch[o][1],mid+1,r,mid+1,R,k)+1ll*(R-L+1)*(R-L+2)/2%mod*tag[o])%mod;
    if(k==1)
        return (query(ch[o][0],l,mid,L,mid,k)+query(ch[o][0],l,mid,L,mid,2)*(R-mid)+query(ch[o][1],mid+1,r,mid+1,R,k)+1ll*(R-L+1)*(R-L+2)/2%mod*tag[o])%mod;
    if(k==3)
        return (query(ch[o][0],l,mid,L,mid,k)+query(ch[o][1],mid+1,r,mid+1,R,k)+query(ch[o][1],mid+1,r,mid+1,R,2)*(mid-L+1)%mod*(mid-L+1)+query(ch[o][1],mid+1,r,mid+1,R,0)*2*(mid-L+1)+1ll*(R-L+1)*(R-L+2)*(2*R-2*L+3)/6%mod*tag[o])%mod;
    if(k==4)
        return (query(ch[o][0],l,mid,L,mid,k)+query(ch[o][1],mid+1,r,mid+1,R,k)+query(ch[o][0],l,mid,L,mid,2)*(R-mid)%mod*(R-mid)+query(ch[o][0],l,mid,L,mid,1)*2*(R-mid)+1ll*(R-L+1)*(R-L+2)*(2*R-2*L+3)/6%mod*tag[o])%mod;
    return (query(ch[o][0],l,mid,L,mid,k)+query(ch[o][1],mid+1,r,mid+1,R,k)+tag[o]*(R-L+1))%mod;
}
void update(int &o,int p,int l,int r,int L,int R,ll v){
    if(idx[o]!=bcnt){
        o=++tcnt;
        idx[o]=bcnt;
        ch[o][0]=ch[p][0];
        ch[o][1]=ch[p][1];
        tag[o]=tag[p];
        sum[o][0]=sum[p][0];
        sum[o][1]=sum[p][1];
        sum[o][2]=sum[p][2];
        sum[o][3]=sum[p][3];
        sum[o][4]=sum[p][4];
    }
    if(L<=l&&r<=R){
        tag[o]=(tag[o]+v)%mod;
        sum[o][0]=(sum[o][0]+1ll*(r-l+1)*(r-l+2)/2%mod*v)%mod;
        sum[o][1]=(sum[o][1]+1ll*(r-l+1)*(r-l+2)/2%mod*v)%mod;
        sum[o][2]=(sum[o][2]+v*(r-l+1))%mod;
        sum[o][3]=(sum[o][3]+1ll*(r-l+1)*(r-l+2)*(2*r-2*l+3)/6%mod*v)%mod;
        sum[o][4]=(sum[o][4]+1ll*(r-l+1)*(r-l+2)*(2*r-2*l+3)/6%mod*v)%mod;
        return;
    }
    int mid=(l+r)>>1;
    if(L<=mid)
        update(ch[o][0],ch[p][0],l,mid,L,R,v);
    if(R>mid)
        update(ch[o][1],ch[p][1],mid+1,r,L,R,v);
    maintain(o,l,r);
    return;
}
ll Query(int x,int y){
    ll ret1=0,ret2=0,ret3=0,ret4=0,ret5=0,sum2=0;
    int siz1=0,siz3=0;
    while(tp[x]!=tp[y]){
        if(dep[tp[x]]>dep[tp[y]]){
            const ll s=query(rt[now],1,n,dfn[tp[x]],dfn[x],2),ss=query(rt[now],1,n,dfn[tp[x]],dfn[x],1);
            ret4=(ret4+query(rt[now],1,n,dfn[tp[x]],dfn[x],4)+s*siz1%mod*siz1+ss*2*siz1)%mod;
            ret1=(ret1+ss+s*siz1)%mod;
            ret3=(ret3+s)%mod;
            siz1+=dfn[x]-dfn[tp[x]]+1;
            siz3+=dfn[x]-dfn[tp[x]]+1;
            x=f[tp[x]];
        }
        else{
            const ll s=query(rt[now],1,n,dfn[tp[y]],dfn[y],2);
            ret5=((ret5+query(rt[now],1,n,dfn[tp[y]],dfn[y],3)+sum2*(dfn[y]-dfn[tp[y]]+1)%mod*(dfn[y]-dfn[tp[y]]+1)+ret2*2*(dfn[y]-dfn[tp[y]]+1)))%mod;
            ret2=(ret2+(dfn[y]-dfn[tp[y]]+1)*sum2+query(rt[now],1,n,dfn[tp[y]],dfn[y],0))%mod;
            ret3=(ret3+s)%mod;
            sum2=(sum2+s)%mod;
            siz3+=dfn[y]-dfn[tp[y]]+1;
            y=f[tp[y]];
        }
    }
    if(dep[x]<dep[y]){
        const ll s=query(rt[now],1,n,dfn[x],dfn[y],2);
        ret5=((ret5+query(rt[now],1,n,dfn[x],dfn[y],3)+sum2*(dfn[y]-dfn[x]+1)%mod*(dfn[y]-dfn[x]+1)+ret2*2*(dfn[y]-dfn[x]+1)))%mod;
        ret2=(ret2+(dfn[y]-dfn[x]+1)*sum2+query(rt[now],1,n,dfn[x],dfn[y],0))%mod;
        ret3=(ret3+s)%mod;
        sum2=(sum2+s)%mod;
        siz3+=dfn[y]-dfn[x]+1;
    }
    else{
        const ll s=query(rt[now],1,n,dfn[y],dfn[x],2),ss=query(rt[now],1,n,dfn[y],dfn[x],1);
        ret4=(ret4+query(rt[now],1,n,dfn[y],dfn[x],4)+s*siz1%mod*siz1+ss*2*siz1)%mod;
        ret1=(ret1+ss+s*siz1)%mod;
        ret3=(ret3+s)%mod;
        siz1+=dfn[x]-dfn[y]+1;
        siz3+=dfn[x]-dfn[y]+1;
    }
    ret4=(ret4+ret5+sum2*siz1%mod*siz1+ret2*2*siz1)%mod;
    ret1=(ret1+sum2*siz1+ret2)%mod;
    return ((1ll*siz3*(siz3+1)%mod*ret3-(2*siz3+1)*ret1+ret4)%mod*inv2%mod+mod)%mod;
}
void Update(int x,int y,int v){
    ++bcnt;
    while(tp[x]!=tp[y]){
        if(dep[tp[x]]<dep[tp[y]])
            swap(x,y);
        update(rt[bcnt],rt[now],1,n,dfn[tp[x]],dfn[x],v);
        x=f[tp[x]];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    update(rt[bcnt],rt[now],1,n,dfn[x],dfn[y],v);
    now=bcnt;
    return;
}
int rd(){
    int x=0;
    char c;
    do c=getchar();
    while(!isdigit(c));
    do{
        x=(x<<1)+(x<<3)+(c^48);
        c=getchar();
    }while(isdigit(c));
    return x;
}
void InitInput(){
    memset(head,-1,sizeof(head));
    n=rd();
    m=rd();
    for(int i=1;i<n;i++){
        int u=rd(),v=rd();
        add_edge(u,v);
        add_edge(v,u);
    }
    for(int i=1;i<=n;i++)
        val[i]=rd();
    dfs1(1);
    dfs2(1,1);
    build(rt[0],1,n);
    return;
}
void Ask(){
    while(m--){
        int opt=rd();
        if(opt==1){
            int x=rd()^lastans,y=rd()^lastans,delta=rd();
            Update(x,y,delta);
        }
        else if(opt==2){
            int x=rd()^lastans,y=rd()^lastans;
            lastans=Query(x,y);
            printf("%d\n",lastans);
        }
        else{
            int x=rd()^lastans;
            now=x;
        }
    }
    return;
}
int main(){
    InitInput();
    Ask();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/ezoiHQM/article/details/81675310