jzoj 5783. 【省选模拟2018.8.8】树 lca+线段树

题目大意:
给你 n ( n <= 3 10 5 ) 个点的一棵树,初始根为 1 ,支持 3 种操作。
1. 把根换成 x
2. x y 两点的 l c a 的子数每个点权值 + x
3. 询问以 x 为根的子树权值和。

分析:
主要是换根操作。先以 1 为根跑 d f s ,然后对两点 l c a 分类讨论。
如果 x y 都为 r o o t 的儿子,且 l c a 不为 r o o t ,那么给他们的子树直接加。
如果一个是 r o o t 的儿子,另一个不是,或者 l c a r o o t ,给整棵树加权值。
如果两个都不是 r o o t 儿子的,且 l c a 不是 r o o t 的父亲,直接给子树加。
如果两个都不是 r o o t 儿子的, l c a r o o t 的父亲,此时找到 x r o o t y r o o t l c a ,设深度大的那个 l c a d ,那么这棵树除了 d 中包含 r o o t 的子树,其他的都加权值。线段树维护即可。

代码:

#include <iostream>
#include <cmath>
#include <cstdio>
#define LL long long

const int maxn=3e5+7;

using namespace std;

struct edge{
    int y,next;
}g[maxn*2];

struct node{
    LL lazy;
    LL sum;
}t[maxn*4];

int n,test,x,y,cnt,root,op;
int ls[maxn],dfn[maxn],last[maxn],dep[maxn];
LL a[maxn],k;
int f[maxn][20];

void add(int x,int y)
{
    g[++cnt]=(edge){y,ls[x]};
    ls[x]=cnt;
}

void dfs(int x,int fa)
{
    dfn[x]=++cnt;
    last[x]=dfn[x];
    f[x][0]=fa;
    for (int i=ls[x];i>0;i=g[i].next)
    {
        int y=g[i].y;
        if (y==fa) continue;
        dep[y]=dep[x]+1;
        dfs(y,x);
        last[x]=max(last[x],last[y]);
    }
}

void clean(int p,int l,int r)
{
    if (t[p].lazy)
    {
        int mid=(l+r)/2;
        t[p*2].lazy+=t[p].lazy;
        t[p*2].sum+=(LL)(mid-l+1)*t[p].lazy;
        t[p*2+1].lazy+=t[p].lazy;
        t[p*2+1].sum+=(LL)(r-mid)*t[p].lazy;
        t[p].lazy=0;
    }
}

void ins(int p,int l,int r,int x,int y,LL k)
{
    if ((l==x) && (r==y))
    {
        t[p].sum+=(LL)(r-l+1)*k;
        t[p].lazy+=k;
        return;
    }
    int mid=(l+r)/2;
    clean(p,l,r);
    if (y<=mid) ins(p*2,l,mid,x,y,k);
    else if (x>mid) ins(p*2+1,mid+1,r,x,y,k);
    else
    {
        ins(p*2,l,mid,x,mid,k);
        ins(p*2+1,mid+1,r,mid+1,y,k);
    }
    t[p].sum=t[p*2].sum+t[p*2+1].sum;
}

LL getsum(int p,int l,int r,int x,int y)
{
    if ((l==x) && (r==y)) return t[p].sum;
    int mid=(l+r)/2;
    clean(p,l,r);
    if (y<=mid) return getsum(p*2,l,mid,x,y);
    else if (x>mid) return getsum(p*2+1,mid+1,r,x,y);
    else
    {
        return getsum(p*2,l,mid,x,mid)+getsum(p*2+1,mid+1,r,mid+1,y);
    }
}

bool isfa(int x,int y)
{
    return ((dfn[x]<=dfn[y]) && (last[x]>=last[y]));
}

int up(int x,int d)
{
    int k=19,t=1<<19;
    while (d)
    {
        if (d>=t) x=f[x][k],d-=t;
        t/=2; k--;
    }
    return x;
}

int lca(int x,int y)
{
    if (dep[x]>dep[y]) swap(x,y);
    int d=dep[y]-dep[x];
    y=up(y,d);
    if (x==y) return x;
    int k=19;
    while (k>=0)
    {
        if (f[x][k]!=f[y][k])
        {
            x=f[x][k];
            y=f[y][k];
        }
        k--;
    }
    return f[x][0];
}

int main()
{
    freopen("tree.in","r",stdin);
    freopen("tree.out","w",stdout);
    scanf("%d%d",&n,&test);
    for (int i=1;i<=n;i++) scanf("%lld",&a[i]);
    for (int i=1;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    cnt=0;  
    dfs(1,0);
    for (int j=1;j<20;j++)
    {
        for (int i=1;i<=n;i++)
        {
            f[i][j]=f[f[i][j-1]][j-1];
        }
    }   
    root=1;
    for (int i=1;i<=n;i++) ins(1,1,n,dfn[i],dfn[i],a[i]);   
    for (int i=1;i<=test;i++)
    {
        scanf("%d",&op);        
        if (op==1) scanf("%d",&root);
        if (op==2)
        {           
            scanf("%d%d%lld",&x,&y,&k);
            if (isfa(root,x) && isfa(root,y)) 
            {
                int d=lca(x,y);
                if (d==root) ins(1,1,n,1,n,k);
                        else ins(1,1,n,dfn[d],last[d],k);
            }
            else
            {
                if (isfa(root,x) || isfa(root,y)) ins(1,1,n,1,n,k);
                else
                {
                    int d=lca(x,y);
                    if (!isfa(d,root))
                    {   
                        ins(1,1,n,dfn[d],last[d],k);
                    }
                    else
                    {
                        int d1=lca(x,root);
                        int d2=lca(y,root);
                        if (dep[d1]<dep[d2]) swap(d1,d2);
                        int c=up(root,dep[root]-dep[d1]-1);
                        if (dfn[c]-1>=1) ins(1,1,n,1,dfn[c]-1,k);
                        if (last[c]+1<=n) ins(1,1,n,last[c]+1,n,k);
                    }
                }
            }
        }
        if (op==3)
        {
            scanf("%d",&x);
            if (isfa(root,x))
            {
                if (x==root) printf("%lld\n",getsum(1,1,n,1,n));
                        else printf("%lld\n",getsum(1,1,n,dfn[x],last[x]));
            }
            else
            {
                if (!isfa(x,root)) printf("%lld\n",getsum(1,1,n,dfn[x],last[x]));
                else
                {
                    int d=lca(x,root);
                    int c=up(root,dep[root]-dep[d]-1);
                    LL ans=0;
                    if (dfn[c]-1>=1) ans+=getsum(1,1,n,1,dfn[c]-1);
                    if (last[c]+1<=n) ans+=getsum(1,1,n,last[c]+1,n);
                    printf("%lld\n",ans);
                }   
            }
        }
    }
}

猜你喜欢

转载自blog.csdn.net/liangzihao1/article/details/81510437
今日推荐