bzoj 4555 [Tjoi2016&Heoi2016]求和 (NTT)

Description

在2016年,佳媛姐姐刚刚学习了第二类斯特林数,非常开心。

现在他想计算这样一个函数的值:

S(i, j)表示第二类斯特林数,递推公式为:

S(i, j) = j ∗ S(i − 1, j) + S(i − 1, j − 1), 1 <= j <= i − 1。

边界条件为:S(i, i) = 1(0 <= i), S(i, 0) = 0(1 <= i)

你能帮帮他吗?

Input

输入只有一个正整数

Output

 输出f(n)。由于结果会很大,输出f(n)对998244353(7 × 17 × 223 + 1)取模的结果即可。1 ≤ n ≤ 100000

Sample Input

3

Sample Output

87

思路:https://www.cnblogs.com/Skyminer/p/6402254.html

反正我没推出来。。。借过公式来直接做的。心累。。。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int MAXN=600010;
const ll mod=998244353;
ll qmod(ll x,ll p)
{
    ll ans=1;
    while(p)
    {
        if(p&1) ans=ans*x%mod;
        p>>=1;
        x=x*x%mod;
    }
    return ans;
}
const ll pri_rt = 3;
ll w[MAXN];
void NTT(ll *x,int n,int p){
    for(int i=0,t=0;i<n;++i){
        if(i > t) swap(x[i],x[t]);
        for(int j=n>>1;(t^=j)<j;j>>=1);
    }
    for(int m=2;m<=n;m<<=1){
        int k = m>>1;
        int wn = qmod(pri_rt,p == 1 ? (mod-1)/m : (mod-1) - (mod-1)/m);
        for(int i=1;i<k;++i) w[i] = 1LL*w[i-1]*wn % mod;
        w[0] = 1;
        for(int i=0;i<n;i+=m){
            for(int j=0;j<k;++j){
                ll u = 1LL*x[i+j+k]*w[j] % mod;
                x[i+j+k] = x[i+j] - u;
                if(x[i+j+k] < 0) x[i+j+k] += mod;
                x[i+j] += u;
                if(x[i+j] >= mod) x[i+j] -= mod;
            }
        }
    }
    if(p == -1){
        ll inv = qmod(n,mod-2);
        for(int i=0;i<n;++i) x[i] = 1LL*x[i]*inv % mod;
    }
}
ll f[MAXN],g[MAXN];
ll sum[MAXN];
ll fac[MAXN],inv[MAXN],p[MAXN];

void init(int N)
{
    fac[0]=1;
    for(int i=1;i<=N;i++)
        fac[i]=fac[i-1]*(ll)i%mod;
    inv[N]=qmod(fac[N],mod-2);
    for(int i=N-1;i>=0;i--)
        inv[i]=1ll*(i+1)*inv[i+1]%mod;
    p[0]=1;
    for(int i=1;i<=N;i++)
        p[i]=p[i-1]*2ll%mod;
}
int n;
int main()
{
    scanf("%d",&n);
    int len=1;
    while(len<=n+1) len<<=1;
    len<<=1;
    init(n);
    for(int i=0;i<=n;i++)
        f[i]=((i%2==1?-1:1)*inv[i]+mod)%mod;
    for(int i=2;i<=n;i++)
        g[i]=(qmod(1ll*i,1ll*(n+1))-1ll*i+mod)%mod*inv[i]%mod*qmod(i-1,mod-2)%mod;
    g[1]=n;

    NTT(f,len,1);
    NTT(g,len,1);
    for(int i=0;i<len;i++)
        sum[i]=f[i]*g[i]%mod;
    NTT(sum,len,-1);

    ll ans=1;
    for(int i=1;i<=n;i++)
        ans=(ans+p[i]*fac[i]%mod*sum[i]%mod+mod)%mod;
    cout<<ans<<endl;
	return 0;
}

猜你喜欢

转载自blog.csdn.net/dllpXFire/article/details/81779360