【BZOJ4555】求和(TJOI&HEOI2016)-第二类斯特林数+NTT

测试地址:求和
做法:本题需要用到第二类斯特林数+NTT。
从题目中给的递推式或者根据组合数学的知识,第二类斯特林数 S ( i , j ) 的组合意义是:将 i 个有区别的球放入 j 个无区别的盒子的方案数。由此我们可以得到通项公式:
S ( i , j ) = 1 j ! k = 0 j ( 1 ) k C j k ( j k ) i
这其实就是一个容斥的形式,相当于枚举强制哪些盒子是空的,至于要乘一个 1 j ! 是因为盒子无区别,而里面算的方案是有区别的。
那么将这个式子代入题目要求的式子,有:
f ( n ) = i = 0 n j = 0 i 2 j k = 0 j ( 1 ) k C j k ( j k ) i
将组合数拆开,整理得:
f ( n ) = i = 0 n j = 0 i 2 j j ! k = 0 j ( 1 ) k k ! ( j k ) i ( j k ) !
我们发现后半部分已经很像一个卷积的形式了,但是因为它还和 i 有关,所以我们想办法把 i 换进去。
我们知道当 j > i S ( i , j ) = 0 ,所以上式中 j 的上限可以换成 n ,那么就可以把 i 换进去,得到:
f ( n ) = j = 0 n 2 j j ! k = 0 j ( 1 ) k k ! i = 0 n ( j k ) i ( j k ) !
那么这个式子的后半部分就是函数 g ( x ) = ( 1 ) x x ! 和函数 h ( x ) = i = 0 n x i x ! 的卷积了,可以用NTT求出,而求 h ( x ) 时,我们发现它是一个等比数列的前缀和,直接用等比数列求和公式求即可。特别地, h ( 0 ) = 1 , h ( 1 ) = n + 1 ,直接用公式算的话这两个会算错。
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const ll g=3;
ll n,fac[100010],inv[100010],invfac[100010];
ll a[1000010]={0},b[1000010]={0};
int r[1000010];

ll power(ll a,ll b)
{
    ll s=1,ss=a;
    while(b)
    {
        if (b&1) s=s*ss%mod;
        ss=ss*ss%mod;b>>=1;
    }
    return s;
}

void NTT(ll *a,ll type,int n)
{
    for(int i=0;i<n;i++)
        if (i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1)
    {
        ll W=power(g,(mod-1)/(mid<<1));
        if (type==-1) W=power(W,mod-2);
        for(int l=0;l<n;l+=(mid<<1))
        {
            ll w=1;
            for(int k=0;k<mid;k++,w=w*W%mod)
            {
                ll x=a[l+k],y=w*a[l+mid+k]%mod;
                a[l+k]=(x+y)%mod;
                a[l+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if (type==-1)
    {
        ll inv=power(n,mod-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*inv%mod;
    }
}

int main()
{
    scanf("%lld",&n);

    fac[0]=fac[1]=inv[1]=invfac[0]=invfac[1]=1;
    for(ll i=2;i<=n;i++)
    {
        fac[i]=fac[i-1]*i%mod;
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
        invfac[i]=invfac[i-1]*inv[i]%mod;
    }

    for(ll i=0;i<=n;i++)
    {
        a[i]=(((i%2)?-1:1)*invfac[i]+mod)%mod;
        if (i==0) b[i]=1;
        if (i==1) b[i]=n+1;
        if (i>1) b[i]=(power(i,n+1)-1+mod)*invfac[i]%mod*inv[i-1]%mod;
    }
    int x=1,bit=0;
    while(x<=(n<<2)) x<<=1,bit++;
    r[0]=0;
    for(int i=1;i<x;i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
    NTT(a,1,x),NTT(b,1,x);
    for(int i=0;i<x;i++)
        a[i]=a[i]*b[i]%mod;
    NTT(a,-1,x);

    ll ans=0;
    for(ll i=0,j=1;i<=n;i++,j=j*2ll%mod)
    {
        ll tmp=j*fac[i]%mod;
        tmp=tmp*a[i]%mod;
        ans=(ans+tmp)%mod;
    }
    printf("%lld",ans);

    return 0;
}

猜你喜欢

转载自blog.csdn.net/maxwei_wzj/article/details/80156015