Codeforces 960G. Bandit Blues 分治NTT+第一类斯特林数+DP

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/baidu_36797646/article/details/85626878

题解:

考虑最简单的DP, f i , j f_{i,j} 表示 i i 个数的排列,上升序列长度为 j j 的方案数,考虑从大到小放数字,即最后一次放最小的数字,容易得到转移 f i , j = f i 1 , j 1 + ( i 1 ) f i 1 , j f_{i,j}=f_{i-1,j-1}+(i-1)f_{i-1,j} ,这其实是第一类斯特林数的递推式。
枚举n放在哪里,答案为 i = 1 n f i 1 , A 1 × f n i , B 1 × C n 1 i 1 \sum_{i=1}^nf_{i-1,A-1}\times f_{n-i,B-1}\times C_{n-1}^{i-1}
但是这样做不够,考虑再化一下式子,定义一块为每个上升序列的数到下一个比它大的数之间的数,比如排列 1   4   3   2   5 1\ 4\ 3\ 2\ 5 ,第一个和最后一个数分别为一块,中间三个数为一块,那么答案实际上就是 f n 1 , A + B 2 × C A + B 2 A 1 f_{n-1,A+B-2}\times C_{A+B-2}^{A-1} ,也就是在 A + B 2 A+B-2 块中选出 A 1 A-1 块放到 n n 前面,其它放到 n n 后面。
这样我们只需要求一个斯特林数就可以了。这里用到第一类斯特林数一个性质,即 S ( n , m ) S(n,m) x x n n 次上升幂的 m m 次项系数,用式子表示就是 x n = x ( x + 1 ) ( x + 2 ) . . . ( x + n 1 ) = i = 0 n 1 S ( n , i ) x i x^{n\uparrow}=x(x+1)(x+2)...(x+n-1)=\sum_{i=0}^{n-1}S(n,i)x^i
感性理解一下上面这个东西的话就是最后一个 ( x + i 1 ) (x+i-1) ,要么是 i 1 i-1 贡献到前面已经有的 x j x^j ,要么是 x x 贡献到前面的 x j 1 x^{j-1}
然后用分治+NTT求出这个多项式的这一项就完成了。

代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=100010;
const int inf=2147483647;
const int mod=998244353,gn=3;
int read()
{
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*f;
}
int n,A,B,rev[Maxn<<2],P[Maxn<<2],P1[Maxn<<2],P2[Maxn<<2];
int Pow(int x,int y)
{
	if(!y)return 1;
	int t=Pow(x,y>>1),re=(LL)t*t%mod;
	if(y&1)re=(LL)re*x%mod;
	return re;
}
void ntt(int *a,int n,int op)
{
	for(int i=0;i<n;i++)
	if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int i=1;i<n;i<<=1)
	{
		int wn;
		if(op==1)wn=Pow(gn,(mod-1)/(i<<1));
		else wn=Pow(gn,mod-1-(mod-1)/(i<<1));
		for(int j=0;j<n;j+=(i<<1))
		{
			int w=1;
			for(int k=0;k<i;k++)
			{
				int t=(LL)a[i+j+k]*w%mod;w=(LL)w*wn%mod;
				a[i+j+k]=(a[j+k]-t+mod)%mod;
				a[j+k]=(a[j+k]+t)%mod;
			}
		}
	}
	if(op==-1)
	{
		int inv=Pow(n,mod-2);
		for(int i=0;i<n;i++)a[i]=(LL)a[i]*inv%mod;
	}
}
vector<int>S[Maxn<<1];int tot=0;
int solve(int l,int r)
{
	int T=++tot;
	if(l==r){S[T].push_back(l),S[T].push_back(1);return T;}
	int mid=l+r>>1;
	int L=solve(l,mid),R=solve(mid+1,r);
	int N=1;
	while(N<=(mid-l+1)*2)N<<=1;
	for(int i=0;i<N;i++)P1[i]=P2[i]=0;
	for(int i=0;i<S[L].size();i++)P1[i]=S[L][i];
	for(int i=0;i<S[R].size();i++)P2[i]=S[R][i];
	rev[0]=0;
	for(int i=1;i<N;i++)rev[i]=((rev[i>>1]>>1)|((i&1)*(N>>1)));
	ntt(P1,N,1),ntt(P2,N,1);
	for(int i=0;i<N;i++)P[i]=(LL)P1[i]*P2[i]%mod;
	ntt(P,N,-1);
	for(int i=0;i<=r-l+1;i++)S[T].push_back(P[i]);
	return T;
}
int C(int a,int b)
{
	if(a<b)return 0;
	int re=1;
	for(int i=a;i>=a-b+1;i--)re=(LL)re*i%mod;
	for(int i=b;i;i--)re=(LL)re*Pow(i,mod-2)%mod;
	return re;
}
int main()
{
	n=read(),A=read(),B=read();
	if(A+B-2>n-1||!A||!B)return puts("0"),0;
    if(n==1)return puts("1"),0;
    solve(0,n-2);
    printf("%d",(LL)S[1][A+B-2]*C(A+B-2,A-1)%mod);
}

猜你喜欢

转载自blog.csdn.net/baidu_36797646/article/details/85626878