CF1111E Tree 虚树 dp

版权声明:本文为博主原创文章,可以转载但是必须声明版权。 https://blog.csdn.net/forever_shi/article/details/88841087

题目链接

题意:
给一棵 n n 个结点的树, q q 次询问,每次询问首先是三个数 k , m , r k,m,r ,接下来跟着 k k 个结点编号,请你将这 k k 个结点分成不超过 m m 组,使得在以 r r 为根的情况下,组内的任意两个结点不存在祖先关系,求方案数对 1 0 9 + 7 10^9+7 取模。根不一定在这 k k 个点内。 n , q < = 1 e 5 , k < = 5 e 5 n,q<=1e5,\sum k<=5e5

题解:
这种每次读入树上若干个点,然后问这些点的信息的问题,第一反应就是去想虚树吧。这个题做法挺多的,可以树状数组/线段树+dfs序、树剖/LCT、虚树。我就讲一下虚树的做法吧。

首先我们考虑加入现在给你一棵树,让你求方案数应该怎么求。首先面对的一个问题是我们要确定一个合适的dp顺序,来保证正确性。我们的做法是按照dfs序来dp,这样能保证子树内的点在根节点之后dp,也就是我们的顺序是从根向子树dp。我们设 d p [ i ] [ j ] dp[i][j] 表示考虑了dfs序的前 i i 小的点分成了 j j 组的合法方案数。那么我们枚举当前点是新形成一个组还是加到原来的某一个组后面,我们知道,它的所有父节点都在不同的组里,那么其余的组是它可以进入的。我们设 c n t [ i ] cnt[i] 表示 i i 有多少个父节点被选中了,于是有 d p [ i ] [ j ] = d p [ i 1 ] [ j 1 ] + d p [ i 1 ] [ j ] ( i c n t [ i ] ) dp[i][j]=dp[i-1][j-1]+dp[i-1][j]*(i-cnt[i]) 。这个式子显然是可以用滚动数组优化的,只保留一维的话写的时候要类似背包那样从大到小枚举,以免算进去一些当前点放进去很多次的情况。

那么解决了给出你树怎么做之后,对于原题,就只需要把每次询问给出的点建出虚树来就可以了。我一开始以为这个换根之后可能会出错,但是想了想发现建的虚树并没有根,是个无根树,那么我们在做的时候从当前规定的根开始在虚树上一边从上向下dfs一边dp就可以了。多组询问注意一下各种清空数组和变量信息。这样就做完了。

代码:

#include <bits/stdc++.h>
using namespace std;

int n,q,hed[100010],cnt,num,xu[100010],f[100010][21],dep[100010];
int m,k,rt,book[100010],sta[500010],tp;
vector<int> v[100010],b;
const long long mod=1e9+7;
long long ans,dp[310];
struct node
{
	int to,next;
}a[200010];
inline int read()
{
	int x=0;
	char s=getchar();
	while(s>'9'||s<'0')
	s=getchar();
	while(s>='0'&&s<='9')
	{
		x=x*10+s-'0';
		s=getchar();
	}
	return x;
}
inline void add(int from,int to)
{
	a[++cnt].to=to;
	a[cnt].next=hed[from];
	hed[from]=cnt;
}
inline void dfs(int x)
{
	xu[x]=++num;
	for(int i=1;i<=20;++i)
	f[x][i]=f[f[x][i-1]][i-1];
	for(int i=hed[x];i;i=a[i].next)
	{
		int y=a[i].to;
		if(y==f[x][0])
		continue;
		f[y][0]=x;
		dep[y]=dep[x]+1;
		dfs(y);
	}
}
inline int cmp(int x,int y)
{
	return xu[x]<xu[y];
}
inline int lca(int x,int y)
{
	if(dep[x]<dep[y])
	swap(x,y);
	for(int i=20;i>=0;--i)
	{
		if(dep[f[x][i]]>=dep[y])
		x=f[x][i];
	}
	if(x==y)
	return x;
	for(int i=20;i>=0;--i)
	{
		if(f[x][i]!=f[y][i]&&dep[f[x][i]])
		{
			x=f[x][i];
			y=f[y][i];
		}
	}
	return f[x][0];
}
inline void add2(int from,int to)
{
	v[from].push_back(to);
	v[to].push_back(from);
}
inline void ins(int x)
{
	if(x==1)
	return;
	if(tp==1)
	{
		sta[++tp]=x;
		return;
	}
	int z=lca(sta[tp],x);
	while(tp>1&&xu[z]<=xu[sta[tp-1]])
	{
		add2(sta[tp-1],sta[tp]);
		--tp;
	}
	if(sta[tp]!=z)
	{
		add2(sta[tp],z);
		sta[tp]=z;
	}
	sta[++tp]=x;
}
inline void solve(int x,int fa,int cnt)
{
	if(book[x])
	{
		for(int i=m;i>=0;--i)
		{
			if(i<=cnt)
			dp[i]=0;
			else
			dp[i]=(dp[i-1]+dp[i]*(i-cnt)%mod)%mod;
		}
	}
	int sz=v[x].size();
	for(int i=0;i<sz;++i)
	{
		int y=v[x][i];
		if(y==fa)
		continue;
		solve(y,x,cnt+book[x]);
	}
	v[x].clear();
	book[x]=0;
}
int main()
{
	n=read();
	q=read();
	for(int i=1;i<=n-1;++i)
	{
		int x=read(),y=read();
		add(x,y);
		add(y,x);
	}	
	dep[1]=1;
	dfs(1);
	for(int qwq=1;qwq<=q;++qwq)
	{
		k=read();
		m=read();
		rt=read();
		b.clear();
		for(int i=1;i<=k;++i)
		{
			int x=read();
			book[x]=1;
			b.push_back(x);
		}
		if(!book[rt])
		{
			++k;
			b.push_back(rt);
		}		
		sort(b.begin(),b.end(),cmp);
		tp=0;
		sta[++tp]=1;
		for(int i=0;i<k;++i)
		ins(b[i]);		
		while(tp)
		{
			if(tp-1)
			add2(sta[tp-1],sta[tp]);
			--tp;
		}
		memset(dp,0,sizeof(dp));
		dp[0]=1;
		solve(rt,0,0);
		ans=0;
		for(int i=0;i<=m;++i)
		ans=(ans+dp[i])%mod;
		printf("%I64d\n",ans);
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/forever_shi/article/details/88841087