版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/baidu_36797646/article/details/85626878
题解:
考虑最简单的DP,
表示
个数的排列,上升序列长度为
的方案数,考虑从大到小放数字,即最后一次放最小的数字,容易得到转移
,这其实是第一类斯特林数的递推式。
枚举n放在哪里,答案为
。
但是这样做不够,考虑再化一下式子,定义一块为每个上升序列的数到下一个比它大的数之间的数,比如排列
,第一个和最后一个数分别为一块,中间三个数为一块,那么答案实际上就是
,也就是在
块中选出
块放到
前面,其它放到
后面。
这样我们只需要求一个斯特林数就可以了。这里用到第一类斯特林数一个性质,即
为
的
次上升幂的
次项系数,用式子表示就是
感性理解一下上面这个东西的话就是最后一个
,要么是
贡献到前面已经有的
,要么是
贡献到前面的
。
然后用分治+NTT求出这个多项式的这一项就完成了。
代码:
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=100010;
const int inf=2147483647;
const int mod=998244353,gn=3;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
int n,A,B,rev[Maxn<<2],P[Maxn<<2],P1[Maxn<<2],P2[Maxn<<2];
int Pow(int x,int y)
{
if(!y)return 1;
int t=Pow(x,y>>1),re=(LL)t*t%mod;
if(y&1)re=(LL)re*x%mod;
return re;
}
void ntt(int *a,int n,int op)
{
for(int i=0;i<n;i++)
if(i<rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
{
int wn;
if(op==1)wn=Pow(gn,(mod-1)/(i<<1));
else wn=Pow(gn,mod-1-(mod-1)/(i<<1));
for(int j=0;j<n;j+=(i<<1))
{
int w=1;
for(int k=0;k<i;k++)
{
int t=(LL)a[i+j+k]*w%mod;w=(LL)w*wn%mod;
a[i+j+k]=(a[j+k]-t+mod)%mod;
a[j+k]=(a[j+k]+t)%mod;
}
}
}
if(op==-1)
{
int inv=Pow(n,mod-2);
for(int i=0;i<n;i++)a[i]=(LL)a[i]*inv%mod;
}
}
vector<int>S[Maxn<<1];int tot=0;
int solve(int l,int r)
{
int T=++tot;
if(l==r){S[T].push_back(l),S[T].push_back(1);return T;}
int mid=l+r>>1;
int L=solve(l,mid),R=solve(mid+1,r);
int N=1;
while(N<=(mid-l+1)*2)N<<=1;
for(int i=0;i<N;i++)P1[i]=P2[i]=0;
for(int i=0;i<S[L].size();i++)P1[i]=S[L][i];
for(int i=0;i<S[R].size();i++)P2[i]=S[R][i];
rev[0]=0;
for(int i=1;i<N;i++)rev[i]=((rev[i>>1]>>1)|((i&1)*(N>>1)));
ntt(P1,N,1),ntt(P2,N,1);
for(int i=0;i<N;i++)P[i]=(LL)P1[i]*P2[i]%mod;
ntt(P,N,-1);
for(int i=0;i<=r-l+1;i++)S[T].push_back(P[i]);
return T;
}
int C(int a,int b)
{
if(a<b)return 0;
int re=1;
for(int i=a;i>=a-b+1;i--)re=(LL)re*i%mod;
for(int i=b;i;i--)re=(LL)re*Pow(i,mod-2)%mod;
return re;
}
int main()
{
n=read(),A=read(),B=read();
if(A+B-2>n-1||!A||!B)return puts("0"),0;
if(n==1)return puts("1"),0;
solve(0,n-2);
printf("%d",(LL)S[1][A+B-2]*C(A+B-2,A-1)%mod);
}