题意:
给一棵
个结点的树,
次询问,每次询问首先是三个数
,接下来跟着
个结点编号,请你将这
个结点分成不超过
组,使得在以
为根的情况下,组内的任意两个结点不存在祖先关系,求方案数对
取模。根不一定在这
个点内。
题解:
这种每次读入树上若干个点,然后问这些点的信息的问题,第一反应就是去想虚树吧。这个题做法挺多的,可以树状数组/线段树+dfs序、树剖/LCT、虚树。我就讲一下虚树的做法吧。
首先我们考虑加入现在给你一棵树,让你求方案数应该怎么求。首先面对的一个问题是我们要确定一个合适的dp顺序,来保证正确性。我们的做法是按照dfs序来dp,这样能保证子树内的点在根节点之后dp,也就是我们的顺序是从根向子树dp。我们设 表示考虑了dfs序的前 小的点分成了 组的合法方案数。那么我们枚举当前点是新形成一个组还是加到原来的某一个组后面,我们知道,它的所有父节点都在不同的组里,那么其余的组是它可以进入的。我们设 表示 有多少个父节点被选中了,于是有 。这个式子显然是可以用滚动数组优化的,只保留一维的话写的时候要类似背包那样从大到小枚举,以免算进去一些当前点放进去很多次的情况。
那么解决了给出你树怎么做之后,对于原题,就只需要把每次询问给出的点建出虚树来就可以了。我一开始以为这个换根之后可能会出错,但是想了想发现建的虚树并没有根,是个无根树,那么我们在做的时候从当前规定的根开始在虚树上一边从上向下dfs一边dp就可以了。多组询问注意一下各种清空数组和变量信息。这样就做完了。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,q,hed[100010],cnt,num,xu[100010],f[100010][21],dep[100010];
int m,k,rt,book[100010],sta[500010],tp;
vector<int> v[100010],b;
const long long mod=1e9+7;
long long ans,dp[310];
struct node
{
int to,next;
}a[200010];
inline int read()
{
int x=0;
char s=getchar();
while(s>'9'||s<'0')
s=getchar();
while(s>='0'&&s<='9')
{
x=x*10+s-'0';
s=getchar();
}
return x;
}
inline void add(int from,int to)
{
a[++cnt].to=to;
a[cnt].next=hed[from];
hed[from]=cnt;
}
inline void dfs(int x)
{
xu[x]=++num;
for(int i=1;i<=20;++i)
f[x][i]=f[f[x][i-1]][i-1];
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(y==f[x][0])
continue;
f[y][0]=x;
dep[y]=dep[x]+1;
dfs(y);
}
}
inline int cmp(int x,int y)
{
return xu[x]<xu[y];
}
inline int lca(int x,int y)
{
if(dep[x]<dep[y])
swap(x,y);
for(int i=20;i>=0;--i)
{
if(dep[f[x][i]]>=dep[y])
x=f[x][i];
}
if(x==y)
return x;
for(int i=20;i>=0;--i)
{
if(f[x][i]!=f[y][i]&&dep[f[x][i]])
{
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
inline void add2(int from,int to)
{
v[from].push_back(to);
v[to].push_back(from);
}
inline void ins(int x)
{
if(x==1)
return;
if(tp==1)
{
sta[++tp]=x;
return;
}
int z=lca(sta[tp],x);
while(tp>1&&xu[z]<=xu[sta[tp-1]])
{
add2(sta[tp-1],sta[tp]);
--tp;
}
if(sta[tp]!=z)
{
add2(sta[tp],z);
sta[tp]=z;
}
sta[++tp]=x;
}
inline void solve(int x,int fa,int cnt)
{
if(book[x])
{
for(int i=m;i>=0;--i)
{
if(i<=cnt)
dp[i]=0;
else
dp[i]=(dp[i-1]+dp[i]*(i-cnt)%mod)%mod;
}
}
int sz=v[x].size();
for(int i=0;i<sz;++i)
{
int y=v[x][i];
if(y==fa)
continue;
solve(y,x,cnt+book[x]);
}
v[x].clear();
book[x]=0;
}
int main()
{
n=read();
q=read();
for(int i=1;i<=n-1;++i)
{
int x=read(),y=read();
add(x,y);
add(y,x);
}
dep[1]=1;
dfs(1);
for(int qwq=1;qwq<=q;++qwq)
{
k=read();
m=read();
rt=read();
b.clear();
for(int i=1;i<=k;++i)
{
int x=read();
book[x]=1;
b.push_back(x);
}
if(!book[rt])
{
++k;
b.push_back(rt);
}
sort(b.begin(),b.end(),cmp);
tp=0;
sta[++tp]=1;
for(int i=0;i<k;++i)
ins(b[i]);
while(tp)
{
if(tp-1)
add2(sta[tp-1],sta[tp]);
--tp;
}
memset(dp,0,sizeof(dp));
dp[0]=1;
solve(rt,0,0);
ans=0;
for(int i=0;i<=m;++i)
ans=(ans+dp[i])%mod;
printf("%I64d\n",ans);
}
return 0;
}