[SOJ617] 小\(\omega\)的数学题

娃娃机毒瘤 qwq

题意简述:求所有长度为\(n\)的排列的逆序对个数的\(k\)次方和(对 \(998244353\) 取模)。注意我们认为\(0^0=1\)。数据范围:\(n\leq 10^7,\ k\leq 1000\)。


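  • 最朴素的想法就是直接dp答案:设\(f_{i, j}\)为所有长度为\(i\)的排列,逆序对个数的\(j\)次方和。插入第\(i\)个元素时会新增\(t\in[0,i-1]\)个逆序对,由二项式定理得\(f_{i,j}=\sum_{t=0}^{i-1}\sum_{l=0}^{j}\binom{j}{l}t^{j-l}f_{i-1,l}\)。但这样做对\(n\leq 10^7\)显然太慢。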
#include <cstdio>
#include <cctype>
#include <cstring>
#include <cassert>
#include <iostream>
#include <algorithm>
// `register` was removed as a storage-class specifier in C++17 (clang rejects it,
// GCC warns), so R now expands to nothing. It is kept only so the existing
// `R int i` declarations throughout the file still compile unchanged.
#define R
#define ll long long
using namespace std;
// N: capacity of the factorial tables (indices up to max(n, k), n <= 1e7).
// M: capacity of the NTT / small-polynomial buffers (main uses the smallest power
//    of two above 3(k+1), i.e. at most 4096 for k <= 1000).
const int N = 1e7 + 10000, M = 1 << 13, mod = 998244353;

// n, k: input values; rev: bit-reversal permutation table filled by getRev.
int n, k, rev[M];
// f, g: NTT work buffers used by main; fac / inv: factorials and inverse factorials;
// strl: Stirling numbers of the second kind S(k, j); y: sampled answers indexed by length.
ll f[M], g[M], fac[N], inv[N], strl[M], y[M];

// Reads the next (optionally negative) decimal integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' immediately
// preceding the first digit makes the value negative.
// If input is exhausted before any digit is seen, x is left as 0
// (the previous version looped forever at EOF).
template <class T> inline void read(T &x) {
    x = 0;
    // Keep the getchar() result as int: storing it in char loses EOF and can hand
    // negative values (bytes >= 0x80) to isdigit, which is undefined behaviour.
    int ch = getchar();
    bool negative = false;
    while (ch != EOF && !isdigit(ch)) {
        negative = (ch == '-');
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}

// Computes base^pw modulo `mod` by binary exponentiation (0^0 == 1).
// base may be any long long value, including negative or >= mod: it is first
// normalised into [0, mod) so that base * base can never overflow.
// pw must be non-negative.
inline ll quickpow(ll base, ll pw) {
    base %= mod;
    if (base < 0) base += mod;
    ll ret = 1;
    while (pw) {
        if (pw & 1) ret = ret * base % mod;
        base = base * base % mod, pw >>= 1;
    }
    return ret;
}

// Binomial coefficient C(n, m) mod `mod`, using the precomputed factorial
// tables fac / inv. Yields 0 whenever m is negative or exceeds n.
inline ll comb(int n, int m) {
    const bool outOfRange = m < 0 || m > n;
    return outOfRange ? 0 : fac[n] * inv[m] % mod * inv[n - m] % mod;
}

// Fills rev[0 .. 2^lim - 1] with the bit-reversal permutation of lim-bit indices.
// The table is rebuilt only when lim differs from the previous call.
inline void getRev(int lim) {
    static int cachedLim = 0;
    if (cachedLim == lim) return;
    cachedLim = lim;
    const int size = 1 << lim;
    for (int idx = 1; idx < size; ++idx) {
        // Reverse of idx = reverse of (idx >> 1) shifted down, plus idx's low bit moved to the top.
        int reversed = rev[idx >> 1] >> 1;
        if (idx & 1) reversed |= size >> 1;
        rev[idx] = reversed;
    }
}

// Modular addition for operands already in [0, mod]: a single conditional
// subtraction brings the sum back below mod.
inline ll addMod(ll a, ll b) {
    const ll sum = a + b;
    return sum < mod ? sum : sum - mod;
}

// In-place number-theoretic transform of a[0 .. 2^lim - 1] modulo `mod`,
// using 3 as the primitive root. Entries are expected to be in [0, mod).
// When inv is non-zero the inverse transform is computed, including the
// final division by the length 2^lim.
inline void ntt(ll *a, int lim, int inv) {
    getRev(lim);
    const int size = 1 << lim;
    for (int pos = 1; pos < size; ++pos)
        if (rev[pos] > pos)
            swap(a[pos], a[rev[pos]]);
    // Iterative butterflies; `half` is half of the current block length.
    for (int half = 1; half < size; half <<= 1) {
        const int block = half << 1;
        ll root = quickpow(3, (mod - 1) / block);
        if (inv) root = quickpow(root, mod - 2);
        for (int start = 0; start < size; start += block) {
            ll twiddle = 1;
            for (int j = start; j < start + half; ++j) {
                const ll odd = a[j + half] * twiddle % mod;
                a[j + half] = addMod(a[j], mod - odd);
                a[j] = addMod(a[j], odd);
                twiddle = twiddle * root % mod;
            }
        }
    }
    if (inv) {
        const ll sizeInverse = quickpow(size, mod - 2);
        for (int pos = 0; pos < size; ++pos)
            a[pos] = a[pos] * sizeInverse % mod;
    }
}

// Lagrange interpolation on the nodes 1..lim.
// Returns P(x) mod `mod`, where P is the unique polynomial of degree < lim with
// P(i) = y[i] for i = 1..lim. Reads the global tables y, fac and inv; fac / inv
// must be valid up to index x - 1. Precondition: x >= 1.
inline ll lagrange(int lim, int x) {
    // x is itself a node: the general formula would invert (x - x) = 0 and read
    // inv[x - 1 - lim] at a negative index, so answer directly from the samples.
    if (x >= 1 && x <= lim) return y[x];
    // prod = (x - 1)(x - 2)...(x - lim) = (x - 1)! / (x - 1 - lim)!
    ll prod = fac[x - 1] * inv[x - 1 - lim] % mod, ret = 0;
    for (int i = 1; i <= lim; ++i) {
        // Basis term for node i: prod / (x - i) / ((i - 1)! (lim - i)!),
        // with sign (-1)^(lim - i) coming from the negative factors (i - j), j > i.
        const ll w = prod * quickpow(x - i, mod - 2) % mod * inv[i - 1] % mod * inv[lim - i] % mod * y[i] % mod;
        ret = addMod(ret, (lim - i) & 1 ? mod - w : w);
    }
    return ret;
}

int main() {
    read(n), read(k);
    int lim = max(n, k), len = 1;
    fac[0] = strl[0] = f[0] = 1;
    for (R int i = 1; i <= lim; ++i)
        fac[i] = fac[i - 1] * i % mod;
    if (k == 0) return printf("%lld\n", fac[n]), 0;
    inv[lim] = quickpow(fac[lim], mod - 2);
    for (R int i = lim - 1; ~i; --i)
        inv[i] = inv[i + 1] * (i + 1) % mod;
    for (R int i = 1; i <= k; ++i) {
        for (R int j = i; j; --j)
            strl[j] = (strl[j - 1] + strl[j] * j) % mod;
        strl[0] = 0;
    }
    while ((1 << len) <= (k + 1) * 3) ++len;
    lim = min(n, 2 * k + 1);
    for (R int i = 2; i <= lim; ++i) {
        for (R int j = 0; j < (1 << len); ++j)
            g[j] = comb(i, j + 1);
        ntt(g, len, 0), ntt(f, len, 0);
        for (R int j = 0; j < (1 << len); ++j)
            f[j] = f[j] * g[j] % mod;
        ntt(f, len, 1);
        for (R int j = 0; j <= k; ++j)
            y[i] = (y[i] + f[j] * fac[j] % mod * strl[j]) % mod;
        for (R int j = k + 1; j < (1 << len); ++j)
            f[j] = 0;
    }
    if (lim == n) return printf("%lld\n", y[n]), 0;
    for (R int i = 1; i <= lim; ++i)
        y[i] = y[i] * inv[i] % mod;
    printf("%lld\n", lagrange(lim, n) * fac[n] % mod);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/suwakow/p/11656921.html