【BZOJ4555】[Tjoi2016&Heoi2016]求和【斯特林数】【FFT/NTT】

题解:
先推一波公式。
这是一个容斥原理的式子: S ( n , m ) = 1 m ! k = 0 m ( 1 ) k C m k ( m k ) n
就是枚举有多少个盒子是空的,容斥一下。由于盒子是一样的,所以最后要除以m!。
=> S ( n , m ) = 1 m ! k = 0 m ( 1 ) k m ! k ! ( m k ) ! ( m k ) n
=> S ( n , m ) = k = 0 m ( 1 ) k 1 k ! ( m k ) ! ( m k ) n
=> S ( n , m ) = k = 0 m ( 1 ) k k ! ( m k ) n ( m k ) !
我们再看过来要求的式子。
f ( n ) = i = 0 n j = 0 i S ( i , j ) 2 j j !
=> f ( n ) = i = 0 n j = 0 n S ( i , j ) 2 j j !
=> f ( n ) = j = 0 n 2 j j ! i = 0 n S ( i , j )
=> f ( n ) = j = 0 n 2 j j ! i = 0 n k = 0 j ( 1 ) k k ! ( j k ) i ( j k ) !
=> f ( n ) = j = 0 n 2 j j ! k = 0 j ( 1 ) k k ! i = 0 n ( j k ) i ( j k ) !
于是我们可以预处理 k = 0 j ( 1 ) k k ! i = 0 n ( j k ) i ( j k ) ! 的值,然后就可以NTT卷积求解啦!
代码

#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=270005;
const ll mod=998244353;
int n,m,rev[N];
ll ans,jc[N],a[N],b[N];
ll fastpow(ll a,ll x){
    ll res=1;
    while(x){
        if(x&1){
            res=res*a%mod;
        }
        x>>=1;
        a=a*a%mod;
    }
    return res;
}
ll getinv(ll x){
    return fastpow(x,mod-2);
}
void ntt(ll *a,int dft){
    for(int i=0;i<n;i++){
        if(i<rev[i]){
            swap(a[i],a[rev[i]]);
        }
    }
    for(int i=1;i<n;i<<=1){
        ll wn=fastpow(3,(mod-1)/i/2);
        if(dft==-1){
            wn=getinv(wn);
        }
        for(int j=0;j<n;j+=i<<1){
            ll w=1,x,y;
            for(int k=j;k<j+i;k++,w=w*wn%mod){
                x=a[k];
                y=w*a[k+i]%mod;
                a[k]=(x+y)%mod;
                a[k+i]=(x-y+mod)%mod;
            }
        }
    }
    if(dft==-1){
        ll inv=getinv(n);
        for(int i=0;i<n;i++){
            a[i]=a[i]*inv%mod;
        }
    }
}
int main(){
    scanf("%d",&m);
    jc[0]=1;
    for(int i=1;i<=m;i++){
        jc[i]=jc[i-1]*i%mod;
    }
    for(int i=0;i<=m;i++){
        a[i]=(fastpow(-1,i)*getinv(jc[i])+mod)%mod;
    }
    b[0]=1;
    b[1]=m+1;
    for(int i=2;i<=m;i++){
        b[i]=(fastpow(i,m+1)-1)*getinv(i-1)%mod*getinv(jc[i])%mod;
    }
    for(n=1;n<=m*2;n<<=1);
    for(int i=0;i<n;i++){
        rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    }
    ntt(a,1);
    ntt(b,1);
    for(int i=0;i<n;i++){
        a[i]=a[i]*b[i]%mod;
    }
    ntt(a,-1);
    for(int i=0;i<=m;i++){
        ans=(ans+fastpow(2,i)*jc[i]%mod*a[i]%mod)%mod;
    }
    printf("%lld\n",ans);
    return 0;

猜你喜欢

转载自blog.csdn.net/ez_2016gdgzoi471/article/details/80204588