【JSOI2019】神经网络(树上背包)(生成函数)(容斥原理)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/zxyoi_dreamer/article/details/102750569

传送门


题解:

由于树与树之间连成了完全图,我们实际上要考虑的每棵树上拿多少条链出来,这个可以直接树上背包求一下,注意这里的链要考虑方向。

现在要把这些树链拿出来排成一个环,环可以旋转不能翻转。且相邻两个位置上放的树链不能来自同一棵树。

设第 i i 棵树拿出 j j 条链排成一个排列的方案数是 f i , j f_{i,j} ,对于“不能来自同一棵树”这个条件,我们考虑容斥,枚举有多少个断点合并。

考虑把这个生成函数写出来: F i ( x ) = j = 1 k i f i , j k = 1 j ( j 1 k 1 ) ( 1 ) j k x j j ! F_i(x)=\sum_{j=1}^{k_i}f_{i,j}\sum_{k=1}^j{j-1\choose k-1}(-1)^{j-k}\frac{x^j}{j!}

把所有生成函数乘起来就是序列的答案,现在考虑处理环。

环的话,我们需要考虑固定一条链为头,并且不能让环头和末端来自同一棵树。

固定环头需要我们把生成函数变成
F i ( x ) = j = 1 k i f i , j j k = 1 j ( j 1 k 1 ) ( 1 ) j k x j 1 ( j 1 ) ! F_i(x)=\sum_{j=1}^{k_i}\frac{f_{i,j}}{j}\sum_{k=1}^j{j-1\choose k-1}(-1)^{j-k}\frac{x^{j-1}}{(j-1)!}

然后减掉环头和末端相同的情况,也就是 j = 1 k i f i , j j k = 1 j ( j 1 k 1 ) ( 1 ) j k x j 2 ( j 2 ) ! \sum_{j=1}^{k_i}\frac{f_{i,j}}{j}\sum_{k=1}^j{j-1\choose k-1}(-1)^{j-k}\frac{x^{j-2}}{(j-2)!}

然后就没了,这些生成函数暴力乘起来即可。


代码:

#include<bits/stdc++.h>
#define ll long long
#define re register
#define cs const

namespace IO{
	inline char gc(){
		static cs int Rlen=1<<22|1;
		static char buf[Rlen],*p1,*p2;
		return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
	}
	template<typename T>
	inline T get(){
		char c;T num;
		while(!isdigit(c=gc()));num=c^48;
		while(isdigit(c=gc()))num=(num+(num<<2)<<1)+(c^48);
		return num;
	}
	inline int gi(){return get<int>();}
}
using namespace IO;

using std::cerr;
using std::cout;

cs int mod=998244353,inv2=mod+1>>1;
inline int add(int a,int b){a+=b-mod;return a+(a>>31&mod);}
inline int dec(int a,int b){a-=b;return a+(a>>31&mod);}
inline int mul(int a,int b){ll r=(ll)a*b;return r>=mod?r%mod:r;}
inline int power(int a,int b){
	int r=1;for(;b;b>>=1,a=mul(a,a))
	if(b&1)r=mul(r,a);return r;
}
inline void Inc(int &a,int b){a+=b-mod;a+=a>>31&mod;}
inline void Dec(int &a,int b){a-=b;a+=a>>31&mod;}
inline void Mul(int &a,int b){a=mul(a,b);}
inline void ex_gcd(int a,int b,int &x,int &y){
	if(!b){x=1,y=0;return ;}ex_gcd(b,a%b,y,x);y-=a/b*x;
}
inline int inv(int a){
	int x,y;ex_gcd(mod,a,y,x);
	return x+(x>>31&mod);
}
cs int N=5e3+7;

int fac[N],ifc[N];
inline void init_fac(){
	fac[0]=1;for(int re i=1;i<N;++i)fac[i]=mul(fac[i-1],i);
	ifc[N-1]=inv(fac[N-1]);
	for(int re i=N-1;i;--i)ifc[i-1]=mul(ifc[i],i);
}
inline int C(int n,int m){return mul(fac[n],mul(ifc[n-m],ifc[m]));}

int n,k,tot;
int el[N],nxt[N<<1],to[N<<1],ec;
inline void adde(int u,int v){
	nxt[++ec]=el[u],el[u]=ec,to[ec]=v;
	nxt[++ec]=el[v],el[v]=ec,to[ec]=u;
}

int dp[N][N][3],siz[N];
inline void mg(int u,int v){
	int a=siz[u],b=siz[v];siz[u]+=siz[v];
	static int tmp[N][3];memset(tmp,0,(a+b+1)*3<<2);
	for(int re i=1;i<=b;++i){
		int all=add(dp[v][i][0],add(dp[v][i][1],dp[v][i][2]));
		for(int re j=1;j<=a;++j){
			Inc(tmp[i+j][0],mul(dp[u][j][0],all));
			Inc(tmp[i+j][1],mul(dp[u][j][1],all));
			Inc(tmp[i+j][2],mul(dp[u][j][2],all));
			Inc(tmp[i+j-1][1],mul(dp[u][j][0],add(dp[v][i][1],mul(dp[v][i][0],2))));
			Inc(tmp[i+j-1][2],mul(dp[u][j][1],add(dp[v][i][0],mul(dp[v][i][1],inv2))));
		}
	}
	memcpy(dp[u],tmp,(a+b+1)*3<<2);
}

void dfs(int u,int p){
	memset(dp[u],0,(siz[u]+1)*3<<2);
	dp[u][1][0]=siz[u]=1;
	for(int re e=el[u],v=to[e];e;v=to[e=nxt[e]])
	if(v!=p)dfs(v,u),mg(u,v);
}

int ta[N],tb[N],ct[N];
inline void poly_mul(int la,int lb){
	static int tp[N];
	memset(tp,0,la+lb<<2);
	for(int re i=0;i<la;++i)
	for(int re j=0;j<lb;++j)Inc(tp[i+j],mul(ta[i],tb[j]));
	memcpy(ta,tp,la+lb<<2);
}

signed main(){
#ifdef zxyoi
	freopen("neural.in","r",stdin);
#endif
	int T=gi();ta[0]=1;init_fac();
	while(T--){
		k=gi();ec=0;memset(el+1,0,k<<2);
		for(int re i=1;i<k;++i)adde(gi(),gi());
		dfs(1,0);for(int re i=1;i<=k;++i)
		ct[i]=mul(T==0?fac[i-1]:fac[i],add(dp[1][i][0],add(dp[1][i][1],dp[1][i][2])));
		for(int re i=1;i<=k;++i){
			int coef=0;
			for(int re j=i;j<=k;++j){
				int val=mul(ct[j],C(j-1,i-1));
				(j-i)&1?Dec(coef,val):Inc(coef,val);
			}
			if(T==0){
				tb[i-1]=coef;
				if(i>1)Dec(tb[i-2],coef);
			}else tb[i]=coef;
		}if(T==0)tb[k]=0;else tb[0]=0;
		for(int re i=0;i<=k;++i)Mul(tb[i],ifc[i]);
		poly_mul(tot+1,k+1);tot+=k;
	}
	int ans=0;
	for(int re i=0;i<=tot;++i)Inc(ans,mul(fac[i],ta[i]));
	cout<<ans<<"\n";
	return 0;
}

猜你喜欢

转载自blog.csdn.net/zxyoi_dreamer/article/details/102750569