传送门
思路:考虑每次怎么做
,直接对于每个点对的
考虑贡献即可。
因此建出虚树然后同样对于
考虑贡献。
代码:
#include<bits/stdc++.h>
#define ri register int
using namespace std;
const int rlen=1<<18|1;
inline char gc(){
static char buf[rlen],*ib,*ob;
(ib==ob)&&(ob=(ib=buf)+fread(buf,1,rlen,stdin));
return ib==ob?-1:*ib++;
}
inline int read(){
int ans=0;
char ch=gc();
while(!isdigit(ch))ch=gc();
while(isdigit(ch))ans=((ans<<2)+ans<<1)+(ch^48),ch=gc();
return ans;
}
typedef long long ll;
const int N=1e6+5,inf=1e9;
int n,dep[N],st[N][20],id[N],dfn[N],m,tot=0,stk[N],top=0,siz[N],mxdep[N],mndep[N];
bool tg[N];
vector<int>e[N],g[N];
void dfs(int p){
dfn[p]=++tot;
for(ri i=1;i<20;++i)st[p][i]=st[st[p][i-1]][i-1];
for(ri i=0,v;i<e[p].size();++i){
if((v=e[p][i])==st[p][0])continue;
st[v][0]=p,dep[v]=dep[p]+1,dfs(v);
}
}
inline int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(ri t=dep[x]-dep[y],i=19;~i;--i)if((t>>i)&1)x=st[x][i];
if(x==y)return x;
for(ri i=19;~i;--i)if(st[x][i]^st[y][i])x=st[x][i],y=st[y][i];
return st[x][0];
}
inline void insert(const int&x){
if(top<2){stk[++top]=x;return;}
int t=lca(x,stk[top]);
if(t==stk[top]){stk[++top]=x;return;}
while(dfn[stk[top]]>dfn[t]){
if(dfn[t]>=dfn[stk[top-1]]){
g[t].push_back(stk[top--]);
if(t^stk[top])stk[++top]=t;
break;
}
g[stk[top-1]].push_back(stk[top]),--top;
}
stk[++top]=x;
}
inline bool cmp(const int&a,const int&b){return dfn[a]<dfn[b];}
ll ans1;
int ans2,ans3;
inline void update(int x,int y){
ans1-=(ll)dep[x]*siz[x]*siz[y]*2,siz[x]+=siz[y];
ans2=max(ans2,mxdep[x]+mxdep[y]-2*dep[x]),mxdep[x]=max(mxdep[x],mxdep[y]);
ans3=min(ans3,mndep[x]+mndep[y]-2*dep[x]),mndep[x]=min(mndep[x],mndep[y]);
}
inline void solve(int p,int fa){
tg[p]?(siz[p]=1,mxdep[p]=mndep[p]=dep[p]):(siz[p]=0,mxdep[p]=-inf,mndep[p]=inf);
for(ri i=0,v;i<g[p].size();++i){
if((v=g[p][i])==fa)continue;
solve(v,p),update(p,v);
}
}
inline void clear(int p,int fa){
for(ri i=0,v;i<g[p].size();++i)if((v=g[p][i])^fa)clear(v,p);
g[p].clear();
}
inline void solve(){
ans1=0,ans2=-inf,ans3=inf;
m=read();
for(ri i=1;i<=m;++i)id[i]=read(),tg[id[i]]=1,ans1+=(ll)dep[id[i]]*(m-1);
sort(id+1,id+m+1,cmp);
if(!tg[1])insert(1);
for(ri i=1;i<=m;++i)insert(id[i]);
while(top){
if(top^1)g[stk[top-1]].push_back(stk[top]);
--top;
}
solve(1,0),clear(1,0);
for(ri i=1;i<=m;++i)tg[id[i]]=0;
cout<<ans1<<' '<<ans3<<' '<<ans2<<'\n';
}
int main(){
n=read();
for(ri i=1,u,v;i<n;++i)
u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
dfs(1);
for(ri tt=read();tt;--tt)solve();
return 0;
}