洛谷3233 BZOJ3572 HNOI2014 世界树 虚树 树形dp

版权声明:本文为博主原创文章,可以转载但是必须声明版权。 https://blog.csdn.net/forever_shi/article/details/84203274

题目链接

题意:
给你一棵n个点的树,边的边权都是1,有m次询问,每次选出若干个点,对于每次询问,每个点要划分给离它最近的被选出来的点,如果有多个距离相同的点,则把这个点划分给这几个距离相同的点中编号最小的点,求每次询问选出的这些点各自分得了多少个点。 n n \sum选出的点的个数 都是3e5量级的。

题解:
每次询问选出树上若干个点的题目还是考虑建出虚树之后树形dp。不过这次树形dp细节比较多,所以主要讲一下树形dp。首先我们建虚树前预处理出每个点在原树中的子树大小,然后对于每次询问,我们建出虚树,虚树上的边边权应该是原树上两点的深度差。首先对于每个被选中的点,它肯定把自己划分给自己,然后我们用两遍dfs的方法求出虚树上每个点划分给哪个点(虚树上有一些点并不是直接被选中的,而是LCA建出来的)。dfs的方法是第一遍用子节点更新父节点,第二遍是用父节点更新子节点。我在第二遍的时候顺便统计了答案,我们计算虚树上每个点对它所属的那个点的贡献,每个点的初始值是他原树上子树的大小。首先我们分两种情况讨论,如果一条边的两段都属于同一个点,那么我们一点把这两个点之间的所有点都划给那个这两个端点划给的点,因为我们是统计每一个点的贡献,所以我们在父节点处要减去子节点所在子树的大小。另一种情况就要复杂一些,如果一条边的两个端点属于两个不同的点,那么我们就要考虑在中间断开,找到原树上两点中间离两个划分点距离相同的点,这个找的过程我们用倍增来实现,这样能保证复杂度不会退化。找到之后就是把这个点两侧分别划给靠近的被选择的点。思路大体就是这样,但是代码实现中细节较多,所以如果思路不是特别清晰的话可以参考一下代码。

代码:

#include <bits/stdc++.h>
using namespace std;

int n,hed[300010],cnt,q,m;
int f[300010][21],dep[300010],dfn[300010],sz[300010],z;
int sta[300010],tp,book[300010],ans[300010],bl[300010];
int dis[300010],s[300010];
struct node
{
    int to,next;
}a[600010];
struct qwq
{
    int x,id;
}b[300010];
vector<int> v[300010];
inline void add(int from,int to)
{
    a[++cnt].to=to;
    a[cnt].next=hed[from];
    hed[from]=cnt;
}
inline void dfs(int x)
{
    sz[x]=1;
    dfn[x]=++z;
    for(int i=1;i<=20;++i)
    f[x][i]=f[f[x][i-1]][i-1];
    for(int i=hed[x];i;i=a[i].next)
    {
        int y=a[i].to;
        if(y==f[x][0])
        continue;
        f[y][0]=x;
        dep[y]=dep[x]+1;
        dfs(y);
        sz[x]+=sz[y];
    }
}
inline int cmp(qwq x,qwq y)
{
    return dfn[x.x]<dfn[y.x];
}
inline int cmp2(qwq x,qwq y)
{
    return x.id<y.id;
}
inline int lca(int x,int y)
{
    if(dep[x]<dep[y])
    swap(x,y);
    for(int i=20;i>=0;--i)
    {
        if(dep[x]-dep[y]>=(1<<i))
        x=f[x][i];
    }
    if(x==y)
    return x;
    for(int i=20;i>=0;--i)
    {
        if(f[x][i]!=f[y][i]&&dep[f[x][i]])
        {
            x=f[x][i];
            y=f[y][i];
        }
    }
    return f[x][0];
}
inline void add2(int from,int to)
{
    v[from].push_back(to);
}
inline void insert(int x)
{
    if(tp==1)
    {
        sta[++tp]=x;
        return;
    }
    int z=lca(sta[tp],x);
    while(tp>1&&dfn[z]<=dfn[sta[tp-1]])
    {
        add2(sta[tp-1],sta[tp]);
        --tp;
    }
    if(sta[tp]!=z)
    {
        add2(z,sta[tp]);
        sta[tp]=z;
    }
    sta[++tp]=x;
}
inline void dfs1(int x)
{
    if(book[x])
    {
        bl[x]=x;
        dis[x]=0;
    }
    else
    dis[x]=2e9;
    s[x]=sz[x];
    ans[x]=0;
    int ji=v[x].size();
    for(int i=0;i<ji;++i)
    {
        int y=v[x][i];
        dfs1(y);
        if(dis[x]>dis[y]+dep[y]-dep[x]||(dis[x]==dis[y]+dep[y]-dep[x]&&bl[x]>bl[y]))
        {
            bl[x]=bl[y];
            dis[x]=dis[y]+dep[y]-dep[x];
        }
    }
}
inline void dfs2(int x)
{
    int ji=v[x].size();
    for(int i=0;i<ji;++i)
    {
        int y=v[x][i];
        if(dis[y]>dis[x]+dep[y]-dep[x]||(dis[y]==dis[x]+dep[y]-dep[x]&&bl[y]>bl[x]))
        {
            bl[y]=bl[x];
            dis[y]=dis[x]+dep[y]-dep[x];
        }
        dfs2(y);		
        if(bl[x]==bl[y])
        s[x]-=sz[y];
        else
        {
            int c=dis[y]+dis[x]+dep[y]-dep[x]-1,gg=y;//c是两个被选择的点之间的距离
            c/=2;//要找中间点于是距离要除以2 
            c-=dis[y];//现在c是子节点到 中间分界点的距离 
            c=dep[y]-c;//c变为分界点的深度 
            if(dep[y]<c)
            gg=y;
            else
            {
                for(int j=20;j>=0;--j)
                {
                    if(dep[f[gg][j]]>=c)
                    gg=f[gg][j];		
                }
            }		
            if(((dis[y]+dis[x]+dep[y]-dep[x]-1)&1)&&bl[x]>bl[y]&&(dis[y]+dis[x]+dep[y]-dep[x]-1)/2-dis[y]>=0)
            gg=f[gg][0];
            s[y]+=sz[gg]-sz[y];
            s[x]-=sz[gg];
        }
        ans[bl[y]]+=s[y];
    }
    if(x==1)
    ans[bl[1]]+=s[1];
    v[x].clear();
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n-1;++i)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dep[1]=1;
    dfs(1);
    scanf("%d",&q);
    for(int i=1;i<=q;++i)
    {
        scanf("%d",&m);
        for(int j=1;j<=m;++j)
        {
            scanf("%d",&b[j].x);
            b[j].id=j;
            book[b[j].x]=1;
        }		
        sort(b+1,b+m+1,cmp);
        sta[++tp]=1;
        for(int j=1;j<=m;++j)
        {
            if(b[j].x!=1)
            insert(b[j].x);
        }
        while(tp>0)
        {
            add2(sta[tp-1],sta[tp]);
            --tp;
        }		
        dfs1(1);
        dfs2(1);
        sort(b+1,b+m+1,cmp2);
        for(int j=1;j<=m;++j)
        printf("%d ",ans[b[j].x]);	
        printf("\n");
        for(int j=1;j<=m;++j)
        book[b[j].x]=0;
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/forever_shi/article/details/84203274
今日推荐