CF960G Bandit Blues

Description

计算 \(1...n\) 的全排列中不同的前缀最大值的个数等于 \(a\) 且不同的后缀最大值的个数等于 \(b\) 的排列的数目对 \(998244353\) 取模的结果

\(1\le n\le 10^5,0\le a,b\le n​\)

Solution

画一画合法的排列可以分析得出以下信息

  • 无论是前缀最大值还是后缀最大值,新出现的最大值 \(k​\) 总是会以 \(\{k,i_1,i_2,...,i_c\}​\) 的形式出现,其中 \(i_1,i_2,...,i_c​\) 为任意 \(c​\) 个小于 \(k​\) 的数,\(c​\) 可以为 \(0​\)

  • 整个排列的最大值,即 \(n​\),会将这个序列分为前后两个排列,而这两个排列是独立的

  • 由于我们只关注排列的相对大小关系,所以被处在下标 \(i​\) 位置的 \(n​\) 划分出的两个小排列可以直接看作 \(\{1,2,...,i-1\}​\)\(\{1,2,...,n-i\}​\) 两组元素的排列

所以我们只需要求 \(1...i\) 的全排列中不同的前缀最大值的个数等于 \(s\) 的排列个数即可,\(i\in [0,n],s\in[1,n]\)

从第一条性质我们可以看出,如果我们将 \(n\) 个元素划分为 \(j\) 个集合,并默认其中的最大值为 \(k\) ,并按每个集合的 \(k\) 将这 \(j\) 个集合从小到大排序,那么这一定是一种符合条件的方案,并且这种构造方法可以包括所有的合法方案

但是如果这样用第二类斯特林数做的话,对于每一个大小为 \(s\) 的集合,我们还需乘上 \((s-1)!\),这样就不太好做了

实际上第一类斯特林数可以很好地解决这个问题,因为它本质上是枚举将长度为 \(n\) 的序列分解为 \(i\) 个循环的方案数,符合我们的要求

所以我们最终要求的是 \(\begin{bmatrix}i\\a-1\end{bmatrix}\)\(\begin{bmatrix}i\\b-1\end{bmatrix}\),其中 \(i\in [0,n-1]\)

答案就是 \(\text{ans}=\sum\limits_{i=1}^{n}\begin{bmatrix}i-1\\a-1\end{bmatrix}\begin{bmatrix}n-i\\b-1\end{bmatrix}​\)

那么怎么快速求第一类斯特林数呢?

我们知道 \(\begin{bmatrix}n\\i\end{bmatrix}\) 的生成函数是 \(x^{\overline{n}}\),但这只能快速求一行的第一类斯特林数,对于求一列的话,复杂度就退化成 \(O(n^2\log n)\)

然后发现没有快速求一列第一类斯特林数的方法。。。

从另一个角度思考问题,不如忽略掉第三条性质,直接生成 \(a+b-2\) 个循环再分配到两边,这样也是对的

那么答案就是 \(\text{ans}=\begin{bmatrix}n-1\\a+b-2\end{bmatrix}\dbinom{a+b-2}{a-1}\)

组合数很好求,下面具体说说如何求第一类斯特林数

因为 \(\begin{bmatrix}n\\i\end{bmatrix}=[x^i](x(x+1)(x+2)...(x+n-1))\) ,所以关键是如何求这 \(n\) 个多项式的卷积

考虑倍增

假设现在已经求出了 \(x(x+1)(x+2)...(x+k-1)\) ,需要求 \(x(x+1)(x+2)...(x+k-1)(x+k)(x+k+1)...(x+2k-1)\)

显然后一半的式子可以通过将 \(x+k\) 带入前一半的式子得到,然后再把两个式子 \(\text{NTT}\) 一下就可以得到我们想要的式子了

怎么带入呢?考虑二项式定理,那么对于 \(p\) 次项,它的系数为
\[ \sum\limits_{i=p}^{k}a_i\dbinom{i}{p}k^{i-p}\tag{1} \]
其中 \(a_i\) 为带入前第 \(i\) 次项的系数

展开这个式子,得到
\[ \frac{1}{p!}\sum\limits_{i=p}^{k}a_ii!\frac{k^{i-p}}{(i-p)!}\tag{2} \]
定义 \(f_i=a_ii!,g_i=\frac{k^i}{i!}\),那么带入后第 \(p\) 次项系数 \(a'_p\)
\[ a'_p=\frac{1}{p!}\sum\limits_{i=p}^{k}f_ig_{i-p}\tag{3} \]
\(g_i\) 中的值全部翻转,得 \(g'_i\),那么
\[ a'_p=\frac{1}{p!}\sum\limits_{i=0}^{k-p}f_{k-i}g'_{p+i}\tag{4} \]
也可以写成
\[ a'_p=\frac{1}{p!}\sum\limits_{i=0}^{k+p}f_ig'_{k+p-i}\tag{5} \]
然后就能 \(\text{NTT}​\) 求出系数了

复杂度 \(O(n\log n)\)

代码如下:

#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
const int N=1e5+10;
const int mod=998244353;
const int G=3;
const int invG=332748118;
int n,A,B,fac[N<<1],inv[N<<1],f[N<<2],g[N<<2],a[N<<2],b[N<<2],k,now,INV;
inline void Preprocess(){
    fac[0]=1;for(register int i=1;i<=(n<<1);i++)fac[i]=1ll*fac[i-1]*i%mod;
    inv[0]=inv[1]=1;for(register int i=2;i<=(n<<1);i++)inv[i]=(-1ll*mod/i*inv[mod%i]%mod+mod)%mod;
    for(register int i=2;i<=(n<<1);i++)inv[i]=1ll*inv[i-1]*inv[i]%mod;
}
inline int C(int n,int m){if(n<0||m<0||n<m)return 0;return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;}
inline int fas(int x,int p){int res=1;while(p){if(p&1)res=1ll*res*x%mod;p>>=1;x=1ll*x*x%mod;}return res;}
inline int MOD(int x){x-=x>=mod? mod:0;return x;}
inline void NTT(int *a,int f){
    for(register int i=0,j=0;i<k;i++){
        if(i>j)swap(a[i],a[j]);
        for(register int l=k>>1;(j^=l)<l;l>>=1);}
    for(register int i=1;i<k;i<<=1){
        int w=fas(~f? G:invG,(mod-1)/(i<<1));
        for(register int j=0;j<k;j+=(i<<1)){
            int e=1;
            for(register int p=0;p<i;p++,e=1ll*e*w%mod){
                int x=a[j+p],y=1ll*a[j+p+i]*e%mod;
                a[j+p]=MOD(x+y);a[j+p+i]=MOD(x-y+mod);
            }
        }
    }
}
inline void Solve(int m){
    if(m==1){f[1]=1;return;}
    int M=m>>1;Solve(M);
    for(register int i=0;i<=M;i++)a[i]=1ll*f[i]*fac[i]%mod;
    now=1;
    for(register int i=0;i<=M;i++)
        b[i]=1ll*now*inv[i]%mod,now=1ll*now*M%mod;
    reverse(b,b+M+1);
    k=1;while(k<=M+M)k<<=1;INV=fas(k,mod-2);
    for(register int i=M+1;i<k;i++)a[i]=b[i]=0;
    NTT(a,1);NTT(b,1);
    for(register int i=0;i<k;i++)a[i]=1ll*a[i]*b[i]%mod;
    NTT(a,-1);
    for(register int i=0;i<k;i++)a[i]=1ll*a[i]*INV%mod;
    for(register int i=0;i<=M;i++)g[i]=1ll*inv[i]*a[M+i]%mod;
    for(register int i=M+1;i<k;i++)g[i]=0;
    NTT(f,1);NTT(g,1);
    for(register int i=0;i<k;i++)f[i]=1ll*f[i]*g[i]%mod;
    NTT(f,-1);
    for(register int i=0;i<=(M<<1);i++)f[i]=1ll*f[i]*INV%mod;
    if((M<<1)!=m){
        for(register int i=m;i;i--)
            f[i]=MOD(f[i-1]+1ll*(m-1)*f[i]%mod);
        f[0]=1ll*f[0]*(m-1)%mod;
    }
}
int main(){
    scanf("%d%d%d",&n,&A,&B);
    if(A+B-2<0||A-1<0||B-1<0||A+B-2>n-1){puts("0");return 0;}
    Preprocess();
    if(n==1){if(A+B-2==0)printf("%d\n",C(A+B-2,A-1));else puts("0");return 0;}
    Solve(n-1);
    printf("%lld\n",1ll*f[A+B-2]*C(A+B-2,A-1)%mod);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/ForwardFuture/p/11522786.html
今日推荐