[Codeforces1097G] Vladislav and a Great Legend

链接:https://codeforces.com/contest/1097/problem/G
大概说一下题意吧:
一棵n个点的树,一个点集S的权值定义为把这个点集连成一个联通块的最少边数
求所有点集的 f ( S ) k f(S)^k 的和

对于这种带次方的,一般考虑两个方法
一个是二项式展开,一个是斯特林数展开
不知道斯特林数的可以看看这个第二类斯特林数
二项式展开我不是特别会用(虽然斯特林数也不会)
一开始是想用前者做的,但是发现不是很会写,并且二项式展开的话复杂度是稳定至少 O ( n k 2 ) O(nk^2) 的。。并不可以通过
但是如果斯特林数可以和子树大小挂钩,那么复杂度就可以降为 O ( n k ) O(nk)

先把式子写成 a n s = i = 0 k { k i } i ! S U ( f ( S ) i ) ans=\sum_{i=0}^k \begin{Bmatrix}k\\i\end{Bmatrix} i! \sum_{S⊆ U}\begin{pmatrix}f(S)\\i\end{pmatrix}
你会发现,右边那个组合数的意义是,我们选择 i i 条边,对应多少个不同的点集
于是就可以考虑DP了
f i , j f_{i,j} 表示以i为根的子树,里面选择了 j j 条边,有多少种选点方案
转移的话
就是先把儿子的合并,然后加上自己的
当然,有两个条件是不合法的要减去
第一个是,子树外面没有选点,但是我们选了x到父亲这条边
第二个是,子树里面没有选点,但是我们选了x到父亲这条边
两个情况均要在DP的时候暴力去掉
但是在去掉第一个的时候,要小心,别把包括第二个也顺便去掉了

CODE:

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
using namespace std;
typedef long long LL;
const LL MOD=1e9+7;
const LL N=100005;
const LL K=205;
LL n,k;
struct qq
{
	LL x,y,last;
}e[N*2];LL num,last[N];
void init (LL x,LL y)
{
	num++;e[num].x=x;e[num].y=y;
	e[num].last=last[x];
	last[x]=num;
}
LL S[K][K];
LL JC[K];
LL h[K];
LL siz[N];
LL f[N][K];
LL g[N];
void dfs (LL x,LL fa)
{
	siz[x]=1;
	f[x][0]=2;
	for (LL xx=last[x];xx!=-1;xx=e[xx].last)
	{
		LL y=e[xx].y;
		if (y==fa) continue;
		dfs(y,x);
		for (LL u=0;u<=min(siz[x]+siz[y]-1,k);u++) g[u]=0;
		for (LL u=0;u<siz[x]&&u<=k;u++)
			for (LL i=0;i<=siz[y]&&(u+i)<=k;i++)
				g[u+i]=(g[u+i]+f[x][u]*f[y][i]%MOD)%MOD;
		siz[x]=siz[x]+siz[y];
		for (LL u=0;u<=min(siz[x]-1,k);u++) f[x][u]=g[u];
	}
	if (x==1)
	{
		for (LL u=0;u<=k;u++) h[u]=h[u]+f[x][u];
	}
	else
	{
		for (LL u=1;u<=k;u++) h[u]=(h[u]-f[x][u-1])%MOD;
		h[1]=(h[1]+1)%MOD;
	}
	for (LL u=k;u>=1;u--) f[x][u]=(f[x][u]+f[x][u-1])%MOD;
	f[x][1]=(f[x][1]-1+MOD)%MOD;
}
int main()
{
	num=0;memset(last,-1,sizeof(last));
	scanf("%lld%lld",&n,&k);
	JC[0]=1;for (LL u=1;u<=k;u++) JC[u]=JC[u-1]*u%MOD;
	S[0][0]=1;
	for (LL u=1;u<=k;u++)
		for (LL i=1;i<=u;i++)
			S[u][i]=(S[u-1][i-1]+S[u-1][i]*i%MOD)%MOD;
	for (LL u=1;u<n;u++)
	{
		LL x,y;
		scanf("%lld%lld",&x,&y);
		init(x,y);init(y,x);
	}
	dfs(1,0);
	LL ans=0;
	for (LL u=0;u<=k;u++) ans=(ans+S[k][u]*JC[u]%MOD*h[u]%MOD)%MOD;
	ans=(ans+MOD)%MOD;
	printf("%lld\n",ans);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_36797743/article/details/86012128