首先我们还是针对每一次询问建出虚树。在虚树上搞事情,这样复杂度就只与关键点个数有关了。
我们考虑一棵虚树怎么统计答案。我们首先两遍dfs求出虚树上每个点最近的关键点,记作bel。
然后考虑虚树上的每一条边x->y.首先倍增的求出x在原树上的儿子(即x->y路径上的第一个点z)。那么原树上在x->y之间的就有sz[z]-sz[y]个点。(即原树上x->y这条链(不含首尾)及挂在这条链上的所有点)。
如果bel[x]==bel[y],那么这些点均加到ans[bel[x]]上即可。
否则在x->y这条链上一定存在一个分界点mid,使得mid~y的点均属于bel[y],x~mid-1的点均属于bel[x]。我们二分求出这个mid,然后统计答案。
挂在虚树上每个点上的还有一些点没有被计算过,我们最后单独统计。(rem[x]表示挂在x这个点上的点,一开始是sz[x],不断减去虚树上自己儿子的sz即可)
复杂度
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
#define ll long long
#define inf 0x3f3f3f3f
#define N 300010
inline char gc(){
static char buf[1<<16],*S,*T;
if(T==S){T=(S=buf)+fread(buf,1,1<<16,stdin);if(S==T) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=gc();
return x*f;
}
int n,m,tot,h[N],num=0,fa[N][20],dep[N],sz[N],dfn[N],dfnum=0;
int a[N],b[N],qq[N],c[N],bel[N],ans[N],rem[N],Log[N];
struct edge{
int to,next;
}data[N<<1];
inline void add(int x,int y){
data[++num].to=y;data[num].next=h[x];h[x]=num;
}
inline void dfs(int x){
sz[x]=1;dfn[x]=++dfnum;
for(int i=1;i<=Log[n];++i){
if(!fa[x][i-1]) break;
fa[x][i]=fa[fa[x][i-1]][i-1];
}for(int i=h[x];i;i=data[i].next){
int y=data[i].to;if(y==fa[x][0]) continue;
fa[y][0]=x;dep[y]=dep[x]+1;dfs(y);sz[x]+=sz[y];
}
}
inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
inline int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
int d=dep[x]-dep[y];
for(int i=0;i<=Log[d];++i)
if(d>>i&1) x=fa[x][i];
if(x==y) return x;
for(int i=Log[n];i>=0;--i)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
inline int dis(int x,int y){return dep[x]+dep[y]-2*dep[lca(x,y)];}
inline void dfs1(int x,int Fa){
c[++tot]=x;rem[x]=sz[x];
for(int i=h[x];i;i=data[i].next){
int y=data[i].to;dfs1(y,x);
if(!bel[x]){bel[x]=bel[y];continue;}
int d1=dis(bel[x],x),d2=dis(bel[y],x);
if(d2<d1||d2==d1&&bel[y]<bel[x]) bel[x]=bel[y];
}
}
inline void dfs2(int x,int Fa){
for(int i=h[x];i;i=data[i].next){
int y=data[i].to;
int d1=dis(bel[y],y),d2=dis(bel[x],y);
if(d2<d1||d2==d1&&bel[x]<bel[y]) bel[y]=bel[x];dfs2(y,x);
}
}
inline void calc(int x,int y){//x->y
int z=y;//原树上x的儿子z
for(int i=Log[dep[y]];i>=0;--i)
if(dep[fa[z][i]]>dep[x]) z=fa[z][i];
rem[x]-=sz[z];
if(bel[x]==bel[y]){ans[bel[x]]+=sz[z]-sz[y];return;}
int mid=y;//mid~y均属于bel[y],mid+1~x均属于bel[x]
for(int i=Log[dep[y]];i>=0;--i){
if(dep[fa[mid][i]]<=dep[x]) continue;
int xx=fa[mid][i],d1=dis(xx,bel[x]),d2=dis(xx,bel[y]);
if(d2<d1||d2==d1&&bel[y]<bel[x]) mid=xx;
}ans[bel[x]]+=sz[z]-sz[mid];
ans[bel[y]]+=sz[mid]-sz[y];
}
inline void solve(){
m=read();tot=0;num=0;int top=0;
for(int i=1;i<=m;++i) a[i]=b[i]=read(),bel[a[i]]=a[i];
sort(a+1,a+m+1,cmp);qq[++top]=1;
for(int i=1;i<=m;++i){
int t=lca(qq[top],a[i]);
while(dep[qq[top]]>dep[t]){
int x=qq[top--];
if(dep[qq[top]]<dep[t]) qq[++top]=t;
add(qq[top],x);
}if(qq[top]!=a[i]) qq[++top]=a[i];
}int x=qq[top--];while(top) add(qq[top],x),x=qq[top--];
dfs1(1,0);dfs2(1,0);
for(int i=1;i<=tot;++i)
for(int j=h[c[i]];j;j=data[j].next)
calc(c[i],data[j].to);
for(int i=1;i<=tot;++i) ans[bel[c[i]]]+=rem[c[i]];
for(int i=1;i<=m;++i) printf("%d ",ans[b[i]]);puts("");
for(int i=1;i<=tot;++i) bel[c[i]]=h[c[i]]=ans[c[i]]=0;
}
int main(){
// freopen("a.in","r",stdin);
n=read();Log[0]=-1;
for(int i=1;i<=n;++i) Log[i]=Log[i>>1]+1;
for(int i=1;i<n;++i){
int x=read(),y=read();add(x,y);add(y,x);
}dfs(1);memset(h,0,sizeof(h));
int owo=read();while(owo--) solve();
return 0;
}