【树DP+FFT】CF981H K Paths

版权声明:这是蒟蒻的BLOG,神犇转载也要吱一声哦~ https://blog.csdn.net/Dream_Lolita/article/details/89453237

【题目】
CF
给定一棵 n n 个节点的树,你需要按顺序选择 k k 条路径(可以相同,先后顺序不同方案不同),使得每一条边要么不被覆盖,要么仅被一条路径覆盖,要么被所有 k k 条路径覆盖。求方案数模 998244353 998244353

【解题思路】
首先考虑暴力,我们枚举一条路径,实际上就是要在两个节点的子树中分别选择 k k 个点,同时每个儿子子树中只能选择一个点,但根节点本身可以选择任意次。

于是现在单独考虑一个点怎么处理出这个东西,设它为 f f ,则生成函数就是
i = 1 m ( s i z s o n i + 1 ) \sum_{i=1}^m(siz_{son_i}+1)
其中 s i z siz 表示子树大小, s o n son 表示儿子节点。设这个东西的 x i x^i 项系数为 a i a_i ,则有:
f x = i = 0 m a i P k i f_x=\sum_{i=0}^m a_i\cdot P_{k}^i
不考虑祖先关系,则现在的答案就是:
1 2 ( ( f x ) 2 f x 2 ) \frac 1 2 ((\sum f_x)^2-\sum f_x^2)
上面这部分可以用分治 FFT \text{FFT} 解决。

还要考虑有祖先关系的点对贡献,那么考虑在较浅的节点处进行计算,设其为 v v ,若选择了 v v 的一个儿子 u u 子树中的节点座位另一个端点,那么实际上 v v 对应的生成函数就要乘上KaTeX parse error: Expected group after '\frac' at end of input: … {1+(n-siz_v)x} {1+(siz_u)x}

乘或除以一个二项式的时间都是 O ( n ) O(n) 的,观察到对于子树大小相同的孩子其贡献多项式是一样的,可以一起计算,那么这个总个数是 O ( n ) O(\sqrt n) 级别的。

于是最后复杂度就是 O ( n log 2 n + n n ) O(n\log ^2n +n\sqrt n) 了。

【参考代码】

#include<bits/stdc++.h>
#define pb push_back
using namespace std;

typedef long long ll;
const int N=262333,mod=998244353,G=3,inv2=(mod+1)>>1;

int read()
{
	int ret=0;char c=getchar();
	while(!isdigit(c)) c=getchar();
	while(isdigit(c)) ret=ret*10+(c^48),c=getchar();
	return ret;
}

namespace Math
{
	int fac[N],ifac[N],inv[N];
	int add(int x){return x>=mod?x-mod:x;}
	int sub(int x){return x<0?x+mod:x;}
	void Add(int &x,int y){x=add(x+y);}
	void Sub(int &x,int y){x=sub(x-y);}
	int mul(int x,int y){return 1ll*x*y%mod;}
	int qpow(int x,int y){int res=1;for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);return res;}
	int getinv(int x){return qpow(x,mod-2);}
	void initmath()
	{
		fac[0]=1;for(int i=1;i<N;++i)fac[i]=mul(fac[i-1],i);
		ifac[N-1]=getinv(fac[N-1]);for(int i=N-2;~i;--i)ifac[i]=mul(ifac[i+1],i+1);
		inv[0]=inv[1]=1;for(int i=2;i<N;++i)inv[i]=mul(mod-mod/i,inv[mod%i]);
	}
	int P(int x,int y){return 1ll*fac[x]*ifac[x-y]%mod;}
}
using namespace Math;

namespace Poly
{
	int m,L,rev[N];
	void ntt(int *a,int n,int op)
	{
		for(int i=0;i<n;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
		for(int i=1;i<n;i<<=1)
		{
			int wn=qpow(G,(mod-1)/(i<<1));
			if(!~op) wn=getinv(wn);
			for(int j=0;j<n;j+=i<<1)
			{
				int w=1;
				for(int k=0;k<i;++k,w=mul(w,wn))
				{
					int x=a[j+k],y=mul(w,a[i+j+k]);
					a[j+k]=add(x+y);a[i+j+k]=sub(x-y);
				}
			}
		}
		if(!~op)for(int i=0,iv=getinv(n);i<n;++i)a[i]=mul(iv,a[i]);
	}
	void reget(int n)
	{
		for(m=1,L=0;m<n;m<<=1,++L);
		for(int i=0;i<m;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
	}
	void polymul(int *a,int *b,int *c)
	{
		//for(int i=0;i<m;++i) printf("%d ",a[i]); puts("");
		//for(int i=0;i<m;++i) printf("%d ",b[i]); puts("");
		ntt(a,m,1);ntt(b,m,1);
		for(int i=0;i<m;++i) c[i]=mul(a[i],b[i]);
		ntt(c,m,-1);
		//for(int i=0;i<m;++i) printf("%d ",c[i]); puts("");
	}
	void polymult(int *a,int *b,int *c,int dega,int degb)
	{
		static int A[N],B[N];
		reget(dega+degb-1);copy(a,a+dega,A);copy(b,b+degb,B);
		//printf("degs:%d %d %d\n",dega,degb,m);
		polymul(A,B,c);
		//for(int i=0;i<m;++i) printf("%d ",c[i]); puts("");
		fill(c+dega+degb-1,c+m,0);fill(A,A+m,0);fill(B,B+m,0);
	}
	void polydec(int *a,int deg,int v)
	{
		static int A[N];
		int coe=getinv(v),iv;
		for(int i=0;i<deg;++i) A[i]=a[i],a[i]=0;
		for(int i=deg-1;i;--i)
		{
			if(A[i])
			{
				a[i-1]=iv=mul(A[i],coe);
				Sub(A[i],mul(iv,coe));Sub(A[i-1],iv);
			}
		}
		fill(A,A+deg,0);
	}
	void polyadd(int *a,int deg,int v)
	{
		for(int i=deg;i;--i) Add(a[i],mul(a[i-1],v));
	}
}
using namespace Poly;

namespace DreamLolita
{
	int n,K,tot,ans;
	int head[N],siz[N],now[N],val[N],tmp[N];
	int f[N],g[N],h[N],F[20][N];
	struct Tway{int v,nex;}e[N];
	void add(int u,int v)
	{
		e[++tot]=(Tway){v,head[u]};head[u]=tot;
		e[++tot]=(Tway){u,head[v]};head[v]=tot;
	}
	void solve(int l,int r,int d)
	{
		if(l==r){F[d][1]=val[l];F[d][0]=1;return;}
		int mid=(l+r)>>1;
		solve(l,mid,d);solve(mid+1,r,d+1);
		polymult(F[d],F[d+1],F[d],mid-l+2,r-mid+1);
		fill(F[d+1],F[d+1]+Poly::m,0);
	}
	int calc(int *a,int len)
	{
		int res=0,lim=min(len,K);
		for(int i=0;i<=lim;++i) Add(res,mul(a[i],P(K,i)));
		return res;
	}
	bool cmp(int x,int y){return siz[x]<siz[y];}
	void dfs1(int x,int fa)
	{
		siz[x]=1;int son=0;
		for(int i=head[x];i;i=e[i].nex)
		{
			int v=e[i].v;
			if(v==fa) continue;
			dfs1(v,x);Add(g[x],g[v]);siz[x]+=siz[v];
		}
		for(int i=head[x];i;i=e[i].nex) if(e[i].v!=fa) val[++son]=siz[e[i].v];;
		if(!son) {f[x]=g[x]=1;return;}
		solve(1,son,0);
		//printf("%d:\n",x);
		//for(int i=1;i<=son;++i) printf("%d!",val[i]); puts("");
		//for(int i=0;i<=son;++i) printf("%d ",F[0][i]); puts("");
		f[x]=calc(F[0],son);Add(g[x],f[x]);son=0;
		//printf("%d:%d\n",x,f[x]);

		for(int i=head[x];i;i=e[i].nex) if(e[i].v!=fa) val[++son]=e[i].v;
		sort(val+1,val+son+1,cmp);
		for(int i=0;i<=son;++i) tmp[i]=F[0][i];
		for(int i=1;i<=son;++i)
		{
			if(siz[val[i]]==siz[val[i-1]]) h[val[i]]=h[val[i-1]];
			else
			{
				for(int j=0;j<=son;++j) now[j]=tmp[j];
				polydec(now,son+1,siz[val[i]]);polyadd(now,son,n-siz[x]);
				h[val[i]]=calc(now,son);
			}
			//printf("%d %d\n",val[i],h[val[i]]);
		}
		fill(F[0],F[0]+Poly::m,0);
		for(int i=0;i<=son;++i) now[i]=tmp[i]=0;
	}
	void dfs2(int x,int fa)
	{
		for(int i=head[x];i;i=e[i].nex)
		{
			int v=e[i].v;
			if(v==fa) continue;
			Add(ans,mul(sub(h[v]-f[x]),g[v]));dfs2(v,x);
		}
	}
	void solution()
	{
		initmath();n=read();K=read();
		if(K==1){printf("%lld\n",1ll*n*(n-1)/2%mod);return;}
		for(int i=1;i<n;++i) add(read(),read());
		dfs1(1,0);dfs2(1,0);
		//for(int i=1;i<=n;++i) printf("%d %d %d\n",f[i],g[i],h[i]);
		//printf("%d\n",ans);
		int sum=0;
		for(int i=1;i<=n;++i) Add(sum,f[i]);
		sum=mul(sum,sum);
		for(int i=1;i<=n;++i) Sub(sum,mul(f[i],f[i]));
		sum=mul(sum,inv2);Add(ans,sum);
		printf("%d\n",ans);
	}
}

int main()
{
#ifdef Durant_Lee
	freopen("CF981H.in","r",stdin);
	freopen("CF981H.out","w",stdout);
#endif
	DreamLolita::solution();
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Dream_Lolita/article/details/89453237