[BZOJ5292][概率DP]BJOI2018:治疗之雨

版权声明:虽然博主很菜,但是还是请注明出处(我觉得应该没人偷我的博客) https://blog.csdn.net/qq_43346903/article/details/88240411

BZOJ5292

实在是对概率不感冒
膜了一下学姐scarlyw的题解
不过学姐的 f f 那里写错了吧,第二个括号里应该是 1 m + 1 \frac{1}{m+1}

Code:

#include<bits/stdc++.h>
using namespace std;
inline int read(){
	int res=0,f=1;char ch=getchar();
	while(!isdigit(ch)) {if(ch=='-') f=-f;ch=getchar();}
	while(isdigit(ch)) {res=(res<<1)+(res<<3)+(ch^48);ch=getchar();}
	return res*f;
}
const int N=1505,mod=1e9+7;
inline int qpow(int a,int b){
	int res=1;for(;b;b>>=1){if(b&1) res=1ll*res*a%mod;a=1ll*a*a%mod;}
	return res;
}
int n,p,m,k,inv[N],f[N][N],pw[N],C[N],a[N];
inline void work(){
	int i,j,over,pp,tmp,hp;
	n=read();p=read();m=read();k=read();
	pp=qpow(m+1,mod-2);
	over=qpow(qpow(m+1,k),mod-2);
	C[0]=1;
	for(int i=1;i<=n;i++) C[i]=1ll*C[i-1]*inv[i]%mod*(k-i+1)%mod;
	for(int i=0;i<=min(n,k);i++)
		pw[i]=1ll*qpow(m,k-i)*C[i]%mod*over%mod;
	f[0][0]=1;f[0][n+1]=0;
	for(int i=1;i<=n;i++) f[0][i]=0;
	for(int i=1;i<=n;i++){
		for(int j=0;j<=n+1;j++) f[i][j]=0;
		f[i][n+1]=1; f[i][i]=hp=1;
		for(int j=0;j<=i;j++){
			tmp=j<=k?1ll*pw[j]*(i==n?0:pp)%mod:0;
			f[i][i-j+1]-=tmp; hp-=tmp;
			if(f[i][i-j+1]<0) f[i][i-j+1]+=mod;
			if(hp<0) hp+=mod;
		}
		if(i<n) f[i][0]-=1ll*hp*pp%mod;
		if(f[i][0]<0) f[i][0]+=mod;
		hp=1;
		for(int j=0;j<=i-1;j++){
			tmp=j<=k?1ll*pw[j]*(i==n?1:1ll*m*pp%mod)%mod:0;
			f[i][i-j]-=tmp; hp-=tmp;
			if(f[i][i-j]<0) f[i][i-j]+=mod;
			if(hp<0) hp+=mod;
		}
		f[i][0]-=i==n?hp:1ll*hp*m%mod*pp%mod;
		if(f[i][0]<0) f[i][0]+=mod;
	}
	for(int i=n;i>=2;i--){
		if(!f[i][i]) return(void) puts("-1");
		int tmp=qpow(f[i][i],mod-2);
		for(int j=0;j<=n+1;j++) f[i][j]=1ll*f[i][j]*tmp%mod;
		tmp=f[i-1][i];
		for(int j=0;j<=n+1;j++){
			f[i-1][j]-=1ll*f[i][j]*tmp%mod;
			if(f[i-1][j]<0) f[i-1][j]+=mod;
		}
	}
	if(!f[1][1]) return(void) puts("-1");
	tmp=qpow(f[1][1],mod-2);
	for(int i=0;i<=n+1;i++) f[1][i]=1ll*f[1][i]*tmp%mod;
	a[0]=0;
	for(int i=1;i<=p;i++){
		a[i]=f[i][n+1];
		for(int j=0;j<=i-1;j++){
			a[i]-=1ll*f[i][j]*a[j]%mod;
			if(a[i]<0) a[i]+=mod;
		}
	}
	cout<<a[p]<<"\n";
}
int main(){
	int i,t=read();
	inv[1]=1;
	for(int i=2;i<=1500;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
	while(t--) work();
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_43346903/article/details/88240411