[集训队作业2018]uoj 449 喂鸽子 - min-max容斥 - dp - NTT

题目大意:有n个数字一开始全0,每次随机一个数字++,问期望多少步后第一次有个位置的数字的值是k。 n 50 , k 1000 n\le50,k\le1000
题解:
显然k=1是个min-max容斥。因此min-max容斥。假设枚举的集合大小是 a a
一开始场上的做法是,期望转为 i 1 P ( a n s i ) \sum_{i\ge1}P(ans\ge i) ,然后 P ( a n s i ) P(ans\ge i) 就是说到 i 1 i-1 的时候还没好的概率,然后再枚举这i-1次操作中多少次落在左半边,整理式子后发现有个组合数数列点积等比数列的求和,总之推导一波后发现要计算大小为a的集合i步后不存在>=k的数值的方案数,发现这玩意只能 n 2 k 2 n^2k^2 dp,但是是卷积,场上写个NTT过了。
另一种做法是直接枚举: i 1 P ( a n s = i ) i \sum_{i\ge 1}P(ans=i)i ,然后还是枚举左边的次数j,右边还是组合数数列点积等比数列的求和(其实算出来就是 n a \frac na ),推导一下发现要算大小为a的集合i步后恰好有个位置是k的方案数,这个就用钦定某个数第i步是出现恰好k次的总方案数减去不合法的位置。后者用个前缀和即可。其实就是推式子很麻烦,但直接冷静思考一波就是a的答案乘以 n a \frac na 的比率即可。

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define p 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
	int x,ch;while((ch=gc)<'0'||ch>'9');
	x=ch^'0';while((ch=gc)>='0'&&ch<='9')
		x=(x<<1)+(x<<3)+(ch^'0');return x;
}
const int N=52,K=1002;
int f[2][N*K],fac[N*K],facinv[N*K],inv[N*K],mi1[N*K],mi2[N*K];
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%p) (k&1)?ans=(lint)ans*x%p:0;return ans; }
inline int C(int n,int m) { return assert(n>=m&&m>=0),(lint)fac[n]*facinv[m]%p*facinv[n-m]%p; }
inline int prelude(int n)
{
	rep(i,fac[0]=1,n) fac[i]=(lint)i*fac[i-1]%p;
	facinv[n]=fast_pow(fac[n],p-2);
	for(int i=n-1;i>=0;i--) facinv[i]=(i+1ll)*facinv[i+1]%p;
	rep(i,1,n) inv[i]=(lint)fac[i-1]*facinv[i]%p;
	return 0;
}
int main()
{
	int n=inn(),k=inn(),ans=0;prelude(n*k);
	rep(i,1,n)
	{
		int *now=f[i&1],*pre=f[(i-1)&1],s=0;
		rep(j,k,(i-1)*(k-1)+1) pre[j]=(pre[j]+(i-1ll)*pre[j-1])%p;
		rep(j,mi1[0]=1,i*(k-1)+1) mi1[j]=mi1[j-1]*(i-1ll)%p;
		rep(j,mi2[0]=1,i*(k-1)+2) mi2[j]=(lint)mi2[j-1]*inv[i]%p;
		rep(j,k,i*(k-1)+1)
			now[j]=(lint)i*C(j-1,k-1)%p*(mi1[j-k]-pre[j-k])%p,
			(now[j]<0?now[j]+=p:0),s=(s+(lint)now[j]*j%p*mi2[j+1])%p;
		if(i&1) ans=(ans+(lint)C(n,i)*s)%p;
		else ans=(ans-(lint)C(n,i)*s)%p,(ans<0?ans+=p:0);
	}
	return !printf("%lld\n",(lint)ans*n%p);
}
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define p 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
	int x,ch;while((ch=gc)<'0'||ch>'9');
	x=ch^'0';while((ch=gc)>='0'&&ch<='9')
		x=(x<<1)+(x<<3)+(ch^'0');return x;
}
const int N=52,K=1005,QWQ=67000;
int dp[N][QWQ],fac[N*K],facinv[N*K],mi[N*K],tmp[QWQ];
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%p) (k&1)?ans=(lint)ans*x%p:0;return ans; }
inline int sol(int x,int s) { return (s&1)?(x?p-x:0):x; }
inline int C(int n,int m) { return (lint)fac[n]*facinv[m]%p*facinv[n-m]%p; }
namespace NTT_space{
	const int N=67000;
	int r[N],*dwg[N],*dwgi[N];
	inline int prelude_dwg()
	{
		int n=N-2;
		for(int i=2,t=1;i<=n;i<<=1,t++)
		{
			dwg[t]=new int[i],dwgi[t]=new int[i];
			int *d=dwg[t],*di=dwgi[t];d[0]=di[0]=1;
			int w=fast_pow(3,(p-1)/i),wi=fast_pow(3,p-1-(p-1)/i);
			rep(j,1,i-1) d[j]=(lint)d[j-1]*w%p,di[j]=(lint)di[j-1]*wi%p;
		}
		return 0;
	}
	inline int pre(int m)
	{
		int n=1,L=0;
		while(n<=m) n<<=1,L++;
		rep(i,1,n-1) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
		return n;
	}
	inline int NTT(int *a,int n,int s)
	{
		rep(i,1,n-1) if(i<r[i]) swap(a[i],a[r[i]]);
		for(int i=2,c=1;i<=n;i<<=1,c++)
		{
			int *d=(s>0?dwg[c]:dwgi[c]);
			for(int j=0,t=i>>1,x,y;j<n;j+=i) rep(k,0,t-1)
				x=a[j+k],y=(lint)d[k]*a[j+k+t]%p,
				a[j+k]=x+y,(a[j+k]>=p?a[j+k]-=p:0),
				a[j+k+t]=x-y,(a[j+k+t]<0?a[j+k+t]+=p:0);
		}
		if(s<0) for(int i=0,v=fast_pow(n,p-2);i<n;i++) a[i]=(lint)a[i]*v%p;
		return 0;
	}
}using NTT_space::NTT;
inline int prelude(int n,int k)
{
	dp[0][0]=1;int m=max(n,n*(k-1));
	rep(i,fac[0]=1,m) fac[i]=(lint)fac[i-1]*i%p;
	facinv[m]=fast_pow(fac[m],p-2);
	for(int i=m-1;i>=0;i--) facinv[i]=(i+1ll)*facinv[i+1]%p;
	int t=NTT_space::pre(n*(k-1));
	NTT(dp[0],t,1);rep(i,0,k-1) tmp[i]=facinv[i];NTT(tmp,t,1);
	rep(i,1,n) rep(j,0,t-1) dp[i][j]=(lint)dp[i-1][j]*tmp[j]%p;
	rep(i,1,n) NTT(dp[i],t,-1);
	rep(i,1,n) rep(j,0,i*(k-1)) dp[i][j]=(lint)dp[i][j]*fac[j]%p;
	return 0;
}
int main()
{
	NTT_space::prelude_dwg();
	int n=inn(),k=inn(),ans=0;prelude(n,k);
	rep(i,1,n)
	{
		int s=0,v=fast_pow(i,p-2);
		rep(j,mi[0]=1,i*(k-1)+1)
			mi[j]=(lint)mi[j-1]*v%p;
		rep(j,0,i*(k-1))
			s=(s+(lint)mi[j+1]*dp[i][j])%p;
		ans+=sol((lint)C(n,i)*s%p,i+1);
		if(ans>=p) ans-=p;
	}
	return !printf("%lld\n",(lint)ans*n%p);
}

猜你喜欢

转载自blog.csdn.net/Mys_C_K/article/details/88165113