[HEOI2016/TJOI2016]求和(第二类斯特林数+NTT)

Address

LuoguP4091

Solution

  • \[ans=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)*2^j*(j!)\]
  • 因为当 \(j>i\) 时 \(S(i,j)=0\),所以可以把 \(j\) 的上界从 \(i\) 扩展到 \(n\):
    \[ans=\sum_{i=0}^{n}\sum_{j=0}^{n}S(i,j)*2^j*(j!)\]
  • 众所周知 :
    \[S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k\]
    因此:
    \[ans=\sum_{i=0}^{n}\sum_{j=0}^{n}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k*2^j\]
  • 发现 \(2^j\) 只包含了变量 \(j\),所以把它提到前面:
    \[ans=\sum_{j=0}^{n}2^j*\sum_{i=0}^{n}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k\]

  • 然后把 \(C_j^k\) 拆成阶乘形式,再整理得:
    \[ans=\sum_{j=0}^{n}2^j\cdot(j!)\cdot\sum_{k=0}^{j}\frac{(-1)^k}{k!}\cdot\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!}\]

  • 于是令 \(f(i)=\frac{(-1)^i}{i!},g(j)=\frac{\sum_{i=0}^nj^i}{j!}\)
  • 当 \(j\neq 1\) 时,\(g(j)\) 可以用等比数列求和公式变成:
    \[g(j)=\frac{j^{n+1}-1}{j!(j-1)}\]
    而 \(j=1\) 时分母为 \(0\),需要单独处理:\(g(1)=n+1\)。

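  • 那么用 \(NTT\) 把 \(f\) 和 \(g\) 卷积起来得到 \(h=f*g\),答案即为 \(ans=\sum_{j=0}^{n}2^j\cdot(j!)\cdot h(j)\)。注意 \(h\) 的次数为 \(2n\),所以 NTT 的长度必须大于 \(2n\)。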

Code

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

// e   : capacity of every global array; presumably sized generously for the
//       problem's bound on n (NTT length lim must stay <= e) — TODO confirm.
// mod : NTT-friendly prime 998244353 = 119 * 2^23 + 1 (primitive root 3).
const int e = 1e6 + 5, mod = 998244353;
// a   : coefficients of f(i) = (-1)^i / i!; later overwritten with f*g.
// b   : coefficients of g(j) = (sum_{i=0}^{n} j^i) / j!.
// rev : bit-reversal permutation for transforms of length lim.
// lim : NTT length (a power of two), starts at 1 and is grown in main.
// n   : the input; ans : the accumulated answer (mod `mod`).
// (The previously declared but never used arrays fa, g, cc, dd were dropped;
//  they only wasted ~16 MB of static memory.)
int a[e], lim = 1, rev[e], b[e], n, ans;

// Modular exponentiation by repeated squaring: returns x^y mod `mod`.
// The exponent is scanned from its least significant bit upward; y must be
// non-negative (y == 0 gives 1).
inline int ksm(int x, int y)
{
    long long acc = 1;
    long long base = x;
    for (; y; y >>= 1, base = base * base % mod)
        if (y & 1)
            acc = acc * base % mod;
    return static_cast<int>(acc);
}

// Iterative in-place number-theoretic transform (NTT) modulo `mod`
// (998244353) with primitive root 3 — despite its name this is not a
// floating-point FFT.
//   n  : transform length; a power of two whose bit-reversal table is rev[].
//   a  : array of n coefficients, overwritten with the transform.
//   op : 1 selects the forward root 3; any other value selects the inverse
//        root 3^{-1}. The inverse transform is NOT divided by n here — the
//        caller has to multiply every entry by n^{-1} itself.
inline void fft(int n, int *a, int op)
{
    // 998244354 / 3 == (mod + 1) / 3, which is 3^{-1} modulo 998244353.
    const int root = (op == 1) ? 3 : 998244354 / 3;

    // Reorder into bit-reversed positions so butterflies can run bottom-up.
    for (int idx = 0; idx < n; ++idx)
        if (idx < rev[idx])
            swap(a[idx], a[rev[idx]]);

    for (int half = 1; half < n; half <<= 1)
    {
        const int len = half << 1;
        // Root of unity of order len for this merge level.
        const int step = ksm(root, (mod - 1) / len);
        for (int start = 0; start < n; start += len)
        {
            int *lo = a + start;
            int *hi = lo + half;
            int tw = 1;
            for (int off = 0; off < half; ++off)
            {
                const int u = lo[off];
                const int v = 1ll * tw * hi[off] % mod;
                lo[off] = (u + v) % mod;
                hi[off] = (u - v + mod) % mod;
                tw = 1ll * tw * step % mod;
            }
        }
    }
}

// Reads n and prints
//   ans = sum_{i=0}^{n} sum_{j=0}^{i} S(i,j) * 2^j * j!   (mod 998244353)
// using ans = sum_j 2^j * j! * (f*g)(j) with
//   f(i) = (-1)^i / i!,   g(j) = (sum_{t=0}^{n} j^t) / j!,
// where the convolution f*g is computed with the NTT routine fft().
int main()
{
    cin >> n;

    // f*g has degree 2n, i.e. 2n+1 coefficients. A cyclic transform of length
    // lim reproduces the linear convolution only when lim > 2n; with
    // lim == 2n (n a power of two) f(n)g(n) would wrap onto index 0.
    int k = 0;
    while (lim <= 2 * n)
    {
        lim <<= 1;
        ++k;
    }
    for (int i = 1; i < lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));

    int fac = 1;
    for (int i = 0; i <= n; i++)
    {
        if (i != 0) fac = 1ll * fac * i % mod;
        const int inv_fac = ksm(fac, mod - 2);

        // f(i) = (-1)^i / i!
        a[i] = 1ll * ((i & 1) ? mod - 1 : 1) * inv_fac % mod;

        // g(i) = (sum_{t=0}^{n} i^t) / i!
        if (i == 0)
            b[i] = 1;          // only the t = 0 term (0^0 = 1) survives
        else if (i == 1)
            b[i] = n + 1;      // geometric-series formula would divide by 0
        else
            b[i] = 1ll * (ksm(i, n + 1) + mod - 1) % mod * ksm(i - 1, mod - 2) % mod
                   * inv_fac % mod;
    }

    fft(lim, a, 1);
    fft(lim, b, 1);
    for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * b[i] % mod;
    fft(lim, a, -1);

    // fft() leaves the inverse transform unscaled; divide by lim once.
    const int inv_lim = ksm(lim, mod - 2);
    for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * inv_lim % mod;

    // ans = sum_{j=0}^{n} 2^j * j! * (f*g)(j)
    int pow2 = 1;
    fac = 1;
    for (int j = 0; j <= n; j++)
    {
        if (j != 0) fac = 1ll * fac * j % mod;
        ans = (ans + 1ll * a[j] * fac % mod * pow2) % mod;
        pow2 = 2ll * pow2 % mod;
    }

    // No fclose(stdin/stdout) here: closing stdout while the synced cout may
    // still flush at exit is undefined behavior.
    cout << ans << '\n';
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/cyf32768/p/12196181.html