bzoj3572 [HNOI2014] World Tree (virtual tree + tree DP + binary lifting + bisection)

For each query we first build a virtual tree (auxiliary tree) over the query's key points and do all the work on it, so the per-query cost depends only on the number of key points (up to logarithmic factors), not on n.

Next, consider how to compute the answers on the virtual tree. Two DFS passes (one bottom-up, one top-down) find, for every virtual-tree node, its nearest key point (ties broken by the smaller index); denote it bel.
Then consider each virtual-tree edge x->y. Using binary lifting, find z, the child of x in the original tree that is an ancestor of y (i.e. the first vertex after x on the path x->y). Exactly sz[z]-sz[y] original-tree vertices lie "between" x and y: the vertices on the x->y chain excluding both endpoints, together with all subtrees hanging off that chain.
If bel[x]==bel[y], all of these vertices are added to ans[bel[x]].
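Otherwise there is a split vertex mid on the chain. The chain vertices from mid down to (but excluding) y, with their hanging subtrees, belong to bel[y]; that is sz[mid]-sz[y] vertices. The chain vertices strictly above mid (and below x), with their hanging subtrees, belong to bel[x]; that is sz[z]-sz[mid] vertices. We find mid by bisection over ancestors (binary lifting) and then add both counts.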

Finally, the vertices hanging directly off each virtual-tree node x (in original-tree subtrees of x that contain no virtual-tree node) have not been counted yet; they all belong to bel[x] and are added separately. rem[x] counts them: it starts at sz[x], and for every virtual-tree child y of x we subtract sz[z], where z is the original-tree child of x on the path to y.

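Complexity: $O(n\log n)$ preprocessing plus $O(k\log^2 n)$ per query with $k$ key points (each bisection step calls an $O(\log n)$ LCA), i.e. $O\big(n\log n + (\sum k)\log^2 n\big)$ overall.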

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
#define ll long long
#define inf 0x3f3f3f3f
#define N 300010
// Buffered single-byte reader over stdin.
// Returns the next byte as a non-negative value (0..255), or EOF once the input
// is exhausted. Returning int instead of char keeps EOF distinguishable from a
// real byte whether plain char is signed or unsigned.
inline int gc(){
    static char buf[1<<16],*S,*T;
    if(T==S){T=(S=buf)+fread(buf,1,1<<16,stdin);if(S==T) return EOF;}
    return static_cast<unsigned char>(*S++);
}
// Read the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. If the input ends before a digit is seen, the
// function returns 0 instead of spinning forever on EOF.
inline int read(){
    int x=0,f=1,ch=gc();
    while(ch!=EOF&&(ch<'0'||ch>'9')){if(ch=='-')f=-1;ch=gc();}
    while(ch>='0'&&ch<='9') x=x*10+(ch-'0'),ch=gc();
    return x*f;
}
// ---------------- global state (array sizes use the N macro) ----------------
int n;          // number of vertices of the original tree
int m;          // number of key points in the current query
int tot;        // number of virtual-tree nodes recorded in c[] by dfs1
int h[N];       // adjacency-list heads: original tree first; main clears it, then solve() reuses it per virtual tree
int num=0;      // number of edges currently stored in the edge pool
int fa[N][20];  // binary lifting: fa[v][i] is the 2^i-th ancestor of v (0 when it does not exist)
int dep[N];     // depth in the original tree (root 1 has depth 0)
int sz[N];      // subtree size in the original tree
int dfn[N];     // preorder index in the original tree
int dfnum=0;    // preorder counter used by dfs
int a[N];       // current query's key points (sorted by dfn in solve)
int b[N];       // current query's key points in input order (for output)
int qq[N];      // stack used while building the virtual tree
int c[N];       // virtual-tree nodes in the order dfs1 visits them
int bel[N];     // nearest key point of a virtual-tree node (0 = not yet known)
int ans[N];     // per-key-point answer accumulated during a query
int rem[N];     // vertices hanging directly off a virtual-tree node that calc() has not assigned
int Log[N];     // Log[i] = floor(log2(i)); main sets Log[0] = -1
// Edge pool shared by the original tree and every virtual tree.
// NOTE: with `using namespace std;` under C++17, an unqualified `data` can be
// ambiguous with std::data; `::data` always names this array.
struct edge{
    int to;     // head vertex of the edge
    int next;   // pool index of the next edge with the same tail (0 ends the list)
};
edge data[N<<1];
inline void add(int x,int y){
    data[++num].to=y;data[num].next=h[x];h[x]=num;
}
inline void dfs(int x){
    sz[x]=1;dfn[x]=++dfnum;
    for(int i=1;i<=Log[n];++i){
        if(!fa[x][i-1]) break;
        fa[x][i]=fa[fa[x][i-1]][i-1];
    }for(int i=h[x];i;i=data[i].next){
        int y=data[i].to;if(y==fa[x][0]) continue;
        fa[y][0]=x;dep[y]=dep[x]+1;dfs(y);sz[x]+=sz[y];
    }
}
// Sort predicate: orders vertices by their preorder index in the original tree.
inline bool cmp(int x,int y){
    const int left=dfn[x];
    const int right=dfn[y];
    return left<right;
}
// Lowest common ancestor of x and y via binary lifting over fa[][].
inline int lca(int x,int y){
    // Make x the deeper vertex, then lift it by the depth difference bit by bit.
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=0,gap=dep[x]-dep[y];gap;gap>>=1,++i)
        if(gap&1) x=fa[x][i];
    if(x==y) return x;
    // Lift both while their 2^i-th ancestors differ; they end just below the LCA.
    for(int i=Log[n];i>=0;--i){
        if(fa[x][i]==fa[y][i]) continue;
        x=fa[x][i];
        y=fa[y][i];
    }
    return fa[x][0];
}
// Number of edges on the original-tree path between x and y.
inline int dis(int x,int y){
    const int w=lca(x,y);
    return (dep[x]-dep[w])+(dep[y]-dep[w]);
}
inline void dfs1(int x,int Fa){
    c[++tot]=x;rem[x]=sz[x];
    for(int i=h[x];i;i=data[i].next){
        int y=data[i].to;dfs1(y,x);
        if(!bel[x]){bel[x]=bel[y];continue;}
        int d1=dis(bel[x],x),d2=dis(bel[y],x);
        if(d2<d1||d2==d1&&bel[y]<bel[x]) bel[x]=bel[y];
    }
}
inline void dfs2(int x,int Fa){
    for(int i=h[x];i;i=data[i].next){
        int y=data[i].to;
        int d1=dis(bel[y],y),d2=dis(bel[x],y);
        if(d2<d1||d2==d1&&bel[x]<bel[y]) bel[y]=bel[x];dfs2(y,x);
    }
}
inline void calc(int x,int y){//x->y
    int z=y;//原树上x的儿子z
    for(int i=Log[dep[y]];i>=0;--i)
        if(dep[fa[z][i]]>dep[x]) z=fa[z][i];
    rem[x]-=sz[z];
    if(bel[x]==bel[y]){ans[bel[x]]+=sz[z]-sz[y];return;}
    int mid=y;//mid~y均属于bel[y],mid+1~x均属于bel[x]
    for(int i=Log[dep[y]];i>=0;--i){
        if(dep[fa[mid][i]]<=dep[x]) continue;
        int xx=fa[mid][i],d1=dis(xx,bel[x]),d2=dis(xx,bel[y]);
        if(d2<d1||d2==d1&&bel[y]<bel[x]) mid=xx;
    }ans[bel[x]]+=sz[z]-sz[mid];
    ans[bel[y]]+=sz[mid]-sz[y];
}
inline void solve(){
    m=read();tot=0;num=0;int top=0;
    for(int i=1;i<=m;++i) a[i]=b[i]=read(),bel[a[i]]=a[i];
    sort(a+1,a+m+1,cmp);qq[++top]=1;
    for(int i=1;i<=m;++i){
        int t=lca(qq[top],a[i]);
        while(dep[qq[top]]>dep[t]){
            int x=qq[top--];
            if(dep[qq[top]]<dep[t]) qq[++top]=t;
            add(qq[top],x);
        }if(qq[top]!=a[i]) qq[++top]=a[i];
    }int x=qq[top--];while(top) add(qq[top],x),x=qq[top--];
    dfs1(1,0);dfs2(1,0);
    for(int i=1;i<=tot;++i)
        for(int j=h[c[i]];j;j=data[j].next)
            calc(c[i],data[j].to);
    for(int i=1;i<=tot;++i) ans[bel[c[i]]]+=rem[c[i]];
    for(int i=1;i<=m;++i) printf("%d ",ans[b[i]]);puts("");
    for(int i=1;i<=tot;++i) bel[c[i]]=h[c[i]]=ans[c[i]]=0;
}
int main(){
//  freopen("a.in","r",stdin);
    n=read();Log[0]=-1;
    for(int i=1;i<=n;++i) Log[i]=Log[i>>1]+1;
    for(int i=1;i<n;++i){
        int x=read(),y=read();add(x,y);add(y,x);
    }dfs(1);memset(h,0,sizeof(h));
    int owo=read();while(owo--) solve();
    return 0;
}

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=324855868&siteId=291194637