[Codeforces755G][DP][NTT]PolandBall and Many Other Balls

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Rose_max/article/details/82946384

翻译

给你n个球,把他们分成K组,允许有的球没有组
每组不能为空也不能超过两个球
求方案数
n<=1e9 K<=2^15

题解

f [ i ] [ j ] f[i][j] 表示前 i i 个球分成 j j 组的方案数
朴素DP容易想到
f [ i ] [ j ] = f [ i 1 ] [ j ] + f [ i 1 ] [ j 1 ] + f [ i 2 ] [ j 1 ] f[i][j]=f[i-1][j]+f[i-1][j-1]+f[i-2][j-1]
优化不了… 没辙
换一个转移方式
对于 2 n 2*n ,可以由 n n 得到
显然有
f [ 2 i ] [ K ] = j = 0 K f [ i ] [ j ] f [ i ] [ K j ] f[2*i][K]=\sum_{j=0}^{K}f[i][j]*f[i][K-j]
发现漏算了一种情况,在分界点的两个球组成一组的情况
所以加上
f [ 2 i ] [ K ] + = j = 0 K 1 f [ i 1 ] [ j ] f [ i 1 ] [ K 1 j ] f[2*i][K]+=\sum_{j=0}^{K-1}f[i-1][j]*f[i-1][K-1-j]
这个可以优化了
上面两个是卷积形式可以NTT优化成 n log n n\log n
倍增DP
维护 f [ i ] , f [ i 1 ] , f [ i 2 ] f[i],f[i-1],f[i-2] 的多项式,用系数当答案

f [ 2 i ] = f [ i ] f [ i ] + f [ i 1 ] f [ i 1 ] f[2*i]=f[i]*f[i]+f[i-1]*f[i-1]
f [ 2 i 1 ] = f [ i ] f [ i 1 ] + f [ i 1 ] f [ i 2 ] f[2*i-1]=f[i]*f[i-1]+f[i-1]*f[i-2]
f [ 2 i 2 ] = f [ i 1 ] f [ i 1 ] + f [ i 2 ] f [ i 2 ] f[2*i-2]=f[i-1]*f[i-1]+f[i-2]*f[i-2]
前面取第K项 后面取第K-1项
类似二进制一样维护,如果扫到这一位有1,那么暴力把这个1的贡献加上
相当于
n=1001000
1000->1001->10010…
暴力做一遍的复杂度是 O ( n ) O(n)
倍增复杂度 log n \log n
NTT的复杂度 n log n n\log n
所以复杂度完美的 n log 2 n n\log^2 n
NTT做的有点多…常数大了有兴趣的可以帮我卡卡啊…

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
#include<ctime>
#include<map>
#define LL long long
#define mp(x,y) make_pair(x,y)
#define mod 998244353
#define MAXN 50005
using namespace std;
inline int read()
{
	int f=1,x=0;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline void write(int x)
{
	if(x<0)putchar('-'),x=-x;
	if(x>9)write(x/10);
	putchar(x%10+'0');
}
inline void print(int x){write(x);printf(" ");}
LL pow_mod(LL a,LL b)
{
	LL ret=1;
	while(b)
	{
		if(b&1)ret=ret*a%mod;
		a=a*a%mod;b>>=1;
	}
	return ret;
}
int R[MAXN*4],L;
void NTT(LL *y,int len,int on)
{
	for(int i=0;i<len;i++)if(i<R[i])swap(y[i],y[R[i]]);
	for(int i=1;i<len;i<<=1)
	{
		LL wn=pow_mod(3,(mod-1)/(i<<1));if(on==-1)wn=pow_mod(wn,mod-2);
		for(int j=0;j<len;j+=(i<<1))
		{
			LL w=1;
			for(int k=0;k<i;k++)
			{
				LL u=y[j+k];
				LL v=y[j+k+i]*w%mod;
				y[j+k]=(u+v)%mod;
				y[j+k+i]=(u-v+mod)%mod;
				w=w*wn%mod;
			}
		}
	}
	if(on==-1)
	{
		LL tmp=pow_mod(len,mod-2);
		for(int i=0;i<len;i++)y[i]=y[i]*tmp%mod;
	}
}
LL A[MAXN*4],B[MAXN*4],C[MAXN*4];
LL n1[MAXN*4],n2[MAXN*4],n3[MAXN*4];
LL s1[MAXN*4],s2[MAXN*4],s3[MAXN*4],s4[MAXN*4],s5[MAXN*4];
int n,K;
void update(int len)
{
	memcpy(n1,A,sizeof(n1));memcpy(n2,B,sizeof(n2));memcpy(n3,C,sizeof(n3));
	NTT(n1,len,1);NTT(n2,len,1);NTT(n3,len,1);
	for(int i=0;i<len;i++)s1[i]=n1[i]*n1[i]%mod;
	for(int i=0;i<len;i++)s2[i]=n2[i]*n2[i]%mod;
	for(int i=0;i<len;i++)s3[i]=n1[i]*n2[i]%mod;
	for(int i=0;i<len;i++)s4[i]=n2[i]*n3[i]%mod;
	for(int i=0;i<len;i++)s5[i]=n3[i]*n3[i]%mod;
	NTT(s1,len,-1);NTT(s2,len,-1);NTT(s3,len,-1);NTT(s4,len,-1);NTT(s5,len,-1);
	for(int i=1;i<=K;i++)A[i]=(s1[i]+s2[i-1])%mod;A[0]=s1[0];
	for(int i=1;i<=K;i++)B[i]=(s3[i]+s4[i-1])%mod;B[0]=s3[0];
	for(int i=1;i<=K;i++)C[i]=(s2[i]+s5[i-1])%mod;C[0]=s2[0];
}
LL tmp[MAXN*4],t1[MAXN*4];
void vio(int ok)
{
	memcpy(tmp,A,sizeof(tmp));
	for(int i=1;i<=min(K,ok);i++)A[i]=(A[i]+tmp[i-1]+B[i-1])%mod;
	memcpy(t1,B,sizeof(t1));memcpy(B,tmp,sizeof(B));memcpy(C,t1,sizeof(C));
}
int gets(int u){int ret=0;for(;u;u>>=1)ret++;return ret;}
int main()
{
	//freopen("a.in","r",stdin);
	//freopen("b.out","w",stdout);
	n=read();K=read();
	int ln=1;
	for(ln=1;ln<=2*K;ln<<=1)L++;
	for(int i=0;i<ln;i++)R[i]=(R[i>>1]>>1)|(i&1)<<(L-1);
	A[0]=1;
	int lg=gets(n);int sum=0;
	for(int i=lg-1;i>=0;i--)
	{
		update(ln);sum<<=1;
		if(sum==2)C[1]=0;
		if(n&(1<<i))sum|=1,vio(sum);
	}
	for(int i=1;i<=K;i++)printf("%lld ",A[i]);
	puts("");
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Rose_max/article/details/82946384