牛客 树上路径(树链)

链接: https://www.nowcoder.com/acm/contest/180/E

思路: 对于前两个操作就是最基本操作,那么问题就在于第三个操作,可以发现第三个操作就是求a1*(a2+a3+...+ an-1+an)+ a2*(a3+ a4+...+ an) +a3*( a4+...+an) +...+an-1*an;

那么转化一下就是[(a1+a2+a3+...+an)*(a1+a2+a3+...+an)-(a1*a1+a2*a2+...+an*an) ]/2;

那么就维护  sum  和 二次方sum就可以了。

代码:

#include<bits/stdc++.h>
#define lson (i<<1)
#define rson (i<<1|1)

using namespace std;
typedef long long ll;
const int N =1e5+5;
const ll mod=1e9+7;
const ll inv2=500000004;

struct eee
{
    int v;
    int next;
}edge[N*2];

int tot,head[N];

struct node
{
    int l,r;
    ll sum1;
    ll sum2;
    ll lz;
}tr[N<<2];

int fat[N]; /// 当前节点的直接父亲
int dep[N]; /// 当前节点的在树上深度
int siz[N]; /// 当前节点的孩子个数
int son[N]; /// 当前节点的重孩子
int rak[N]; /// 线段树的第i个节点是?
int top[N]; /// 当前节点的链开始节点 top
int idd[N]; /// x在线段树中第几个节点
int cnt;

int n,m;
ll a[N];

void init()
{
    tot=0;
    cnt=0;
    memset(head,-1,sizeof(head));
    memset(son,0,sizeof(son));
    memset(siz,0,sizeof(siz));
}

void add(int u,int v)
{
    edge[++tot].v=v; edge[tot].next=head[u]; head[u]=tot;
}

void dfs1(int u,int fa,int deep)
{
    fat[u]=fa;
    dep[u]=deep;
    siz[u]=1;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v==fa) continue;
        dfs1(v,u,deep+1);
        siz[u]+=siz[v];
        if(son[u]==0||siz[v]>siz[son[u]]){
            son[u]=v;
        }
    }
}

void dfs2(int u,int t)
{
    top[u]=t;
    idd[u]=++cnt;
    rak[cnt]=u;
    if(!son[u]) return ;
    dfs2(son[u],t);

    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        //if(v==fat[u]) continue;
        if(v!=son[u]&&v!=fat[u]){
            dfs2(v,v);
        }
    }
}

void push_up(int i)
{
    tr[i].sum1=(tr[lson].sum1+tr[rson].sum1)%mod;
    tr[i].sum2=(tr[lson].sum2+tr[rson].sum2)%mod;
}

void build(int i,int l,int r)
{
    tr[i].l=l; tr[i].r=r; tr[i].sum1=tr[i].sum2=0;
    if(l==r){
        tr[i].sum1=a[rak[l]];
        tr[i].sum2=(a[rak[l]]*a[rak[l]])%mod;
        //cout<<"sum1 "<<tr[i].sum1<<" sum2 "<<tr[i]sum2<<endl;
        return ;
    }
    int mid=(l+r)>>1;
    build(lson,l,mid);
    build(rson,mid+1,r);
    push_up(i);
}

void solve(int i,ll val)
{
    tr[i].lz=(tr[i].lz+val)%mod;
    ll cnt=tr[i].r-tr[i].l+1;
    ll tmp1=(tr[i].sum1*2%mod*val)%mod;
    ll tmp2=(cnt*val%mod*val)%mod;
    tr[i].sum2=(tr[i].sum2+tmp1+tmp2)%mod;
    tr[i].sum1=(tr[i].sum1+cnt*val%mod)%mod;
}

void push_down(int i)
{
    if(tr[i].lz){
        ll &lz=tr[i].lz;
        solve(lson,lz);
        solve(rson,lz);
        lz=0;
    }
}

void update(int i,int l,int r,ll val)
{
    if(tr[i].l==l&&tr[i].r==r){
        solve(i,val);
        return ;
    }
    push_down(i);
    int mid=(tr[i].l+tr[i].r)>>1;
    if(r<=mid) update(lson,l,r,val);
    else if(l>mid) update(rson,l,r,val);
    else{
        update(lson,l,mid,val);
        update(rson,mid+1,r,val);
    }
    push_up(i);
}

void query(int i,int l,int r,ll &sum1,ll &sum2)
{
    if(tr[i].l==l&&tr[i].r==r){
        sum1+=tr[i].sum1;
        sum1%=mod;
        sum2+=tr[i].sum2;
        sum2%=mod;
        return ;
    }
    push_down(i);
    int mid=(tr[i].l+tr[i].r)>>1;
    if(r<=mid) return query(lson,l,r,sum1,sum2);
    else if(l>mid) return query(rson,l,r,sum1,sum2);
    else{
        query(lson,l,mid,sum1,sum2);
        query(rson,mid+1,r,sum1,sum2);
    }
}

ll querys(int x,int y)
{
    ll sum1=0,sum2=0;
    int fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx]>=dep[fy]){
            query(1,idd[fx],idd[x],sum1,sum2);
            x=fat[fx]; fx=top[x];
        }
        else{
            query(1,idd[fy],idd[y],sum1,sum2);
            y=fat[fy]; fy=top[y];
        }
    }

    if(idd[x]<=idd[y]){
        query(1,idd[x],idd[y],sum1,sum2);
    }
    else{
        query(1,idd[y],idd[x],sum1,sum2);
    }
    ll ans=(sum1*sum1%mod-sum2+mod)*inv2%mod;
    return ans;
}

void updates(int x,int y,ll c)
{
    int fx=top[x]; int fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx]>=dep[fy]){
            update(1,idd[fx],idd[x],c);
            x=fat[fx];
        }
        else{
            update(1,idd[fy],idd[y],c);
            y=fat[fy];
        }
        fx=top[x];
        fy=top[y];
    }
    if(idd[x]<=idd[y]) update(1,idd[x],idd[y],c);
    else update(1,idd[y],idd[x],c);
}

int main()
{
    scanf("%d %d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
    init();
    int u,v;
    for(int i=1;i<n;i++){
        scanf("%d %d",&u,&v);
        add(u,v);
        add(v,u);
    }

    dfs1(1,1,0);
    dfs2(1,1);
    build(1,1,n);

    int op;
    ll val;
    while(m--)
    {
        scanf("%d",&op);
        if(op==1){
            scanf("%d %lld",&u,&val);
            update(1,idd[u],idd[u]+siz[u]-1,val);
        }
        else if(op==2){
            scanf("%d %d %lld",&u,&v,&val);
            updates(u,v,val);
        }
        else{
            scanf("%d %d",&u,&v);
            ll Ans=querys(u,v);
            printf("%lld\n",Ans);
        }
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/yjt9299/article/details/82591659