娃娃机毒瘤 qwq
题意简述:求所有长度为\(n\)的排列的逆序对个数的\(k\)次方和,答案对\(998244353\)取模。注意我们认为\(0^0=1\)。\(n\leq 10^7, k\leq 1000\)。
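- 最朴素的想法就是直接dp答案:设\(f_{i, j}\)为所有长度为\(i\)的排列,逆序对个数的\(j\)次方和。把\(i\)插入长度为\(i-1\)的排列时会新增\(t\in[0,i-1]\)个逆序对,由二项式定理得\(f_{i,j}=\sum_{t=0}^{i-1}\sum_{l=0}^{j}\binom{j}{l}t^{j-l}f_{i-1,l}\)。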
#include <cstdio>
#include <cctype>
#include <cstring>
#include <cassert>
#include <iostream>
#include <algorithm>
// `register` was removed from the language in C++17 (clang rejects it by
// default, GCC warns), so R now expands to nothing; the macro is kept only so
// the existing `R int i` loop headers keep compiling.
#define R
#define ll long long
using namespace std;
// N: size of the factorial tables fac[] / inv[] (n <= 1e7).
// M: size of the polynomial and NTT scratch buffers.
// mod: the NTT-friendly prime 998244353 = 119 * 2^23 + 1 (the NTT uses root 3).
const int N = 1e7 + 10000, M = 1 << 13, mod = 998244353;
// rev: bit-reversal permutation used by the NTT.
int n, k, rev[M];
// fac / inv: factorials and inverse factorials mod `mod`;
// strl: Stirling numbers of the second kind S(k, j);
// y: per-length answers used as interpolation samples; f / g: polynomial buffers.
ll f[M], g[M], fac[N], inv[N], strl[M], y[M];
// Reads one (optionally negative) integer from `in` (stdin by default) into x.
// Non-digit characters before the number are skipped; a '-' immediately
// preceding the first digit makes the value negative. The character right
// after the last digit is consumed. If the stream ends before any digit is
// found, x is left as 0 instead of looping forever.
template <class T> inline void read(T &x, FILE *in = stdin) {
    x = 0;
    // int, not char: getc() reports EOF out of band, and isdigit() is only
    // defined for EOF and values representable as unsigned char.
    int ch = getc(in);
    bool neg = false;
    while (!isdigit(ch)) {
        if (ch == EOF) return;
        neg = (ch == '-');  // only the character right before the digits counts
        ch = getc(in);
    }
    while (isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getc(in);
    }
    if (neg) x = -x;
}
// Computes base^pw modulo `mod` by binary exponentiation (pw must be >= 0;
// base^0 is 1, matching the 0^0 = 1 convention).
// The base is reduced into [0, mod) first: otherwise an input >= ~3.04e9
// overflows in base * base, and a negative base leaks a negative "residue"
// to the caller. The loop runs while pw > 0, so a negative pw returns 1
// instead of spinning forever on the arithmetic right shift.
inline ll quickpow(ll base, ll pw) {
    base %= mod;
    if (base < 0) base += mod;
    ll ret = 1;
    for (; pw > 0; pw >>= 1) {
        if (pw & 1) ret = ret * base % mod;
        base = base * base % mod;
    }
    return ret;
}
// Binomial coefficient C(n, m) modulo `mod`, looked up from the factorial
// table fac[] and the inverse-factorial table inv[]. Returns 0 when m lies
// outside [0, n]. Both tables must be filled at least up to index n.
inline ll comb(int n, int m) {
    const bool outOfRange = m < 0 || m > n;
    if (outOfRange) return 0;
    const ll numerator = fac[n];
    const ll denominator = inv[m] * inv[n - m] % mod;
    return numerator * denominator % mod;
}
// Fills rev[0 .. 2^lim) with the bit-reversal permutation of lim-bit indices
// for the iterative NTT. The last size is cached, so repeated calls with the
// same lim do nothing. rev[0] is never written and stays 0.
inline void getRev(int lim) {
    static int cached = 0;
    if (cached == lim) return;
    cached = lim;
    const int size = 1 << lim;
    const int topBit = (1 << lim) >> 1;  // bit that the lowest bit of i maps to
    for (int i = 1; i < size; ++i) {
        const int shifted = rev[i >> 1] >> 1;
        rev[i] = (i & 1) ? (shifted | topBit) : shifted;
    }
}
// Sum of two residues in [0, mod], reduced with one conditional subtraction
// instead of a % operation.
inline ll addMod(ll a, ll b) {
    const ll sum = a + b;
    if (sum >= mod) return sum - mod;
    return sum;
}
// In-place iterative number-theoretic transform of a[0 .. 2^lim) modulo
// `mod`, using 3 as the primitive root. inv == 0 selects the forward
// transform; any other value selects the inverse transform, which also
// multiplies every entry by 2^-lim at the end.
inline void ntt(ll *a, int lim, int inv) {
    getRev(lim);
    const int size = 1 << lim;
    // Bit-reversal reorder so the butterflies can run bottom-up.
    for (int i = 1; i < size; ++i)
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int half = 1; half < size; half <<= 1) {
        const int span = half << 1;
        // Principal span-th root of unity, or its inverse for the inverse transform.
        ll step = quickpow(3, (mod - 1) / span);
        if (inv) step = quickpow(step, mod - 2);
        for (int base = 0; base < size; base += span) {
            ll w = 1;
            for (int j = base; j < base + half; ++j) {
                const ll t = a[j + half] * w % mod;
                a[j + half] = addMod(a[j], mod - t);
                a[j] = addMod(a[j], t);
                w = w * step % mod;
            }
        }
    }
    if (inv) {
        const ll scale = quickpow(size, mod - 2);
        for (int i = 0; i < size; ++i) a[i] = a[i] * scale % mod;
    }
}
// Evaluates at x the unique polynomial of degree < lim through the points
// (i, y[i]) for i = 1..lim (Lagrange interpolation on consecutive nodes),
// modulo `mod`:
//   P(x) = sum_i y[i] * prod_{j != i} (x - j) / (i - j), with
//   prod_{j=1..lim} (x - j) = (x-1)! / (x-1-lim)!        (x > lim), and
//   prod_{j != i} (i - j)   = (i-1)! * (lim-i)! * (-1)^(lim-i).
// Requires x >= 1 and fac/inv filled up to max(x - 1, lim). When x is itself
// a node (1 <= x <= lim), y[x] is returned directly. The product formula
// would otherwise read inv[] at a negative index and divide by x - x = 0.
inline ll lagrange(int lim, int x) {
    if (x >= 1 && x <= lim) return y[x];
    const ll prod = fac[x - 1] * inv[x - 1 - lim] % mod;
    ll ret = 0;
    for (int i = 1; i <= lim; ++i) {
        // Numerator: the full product with the (x - i) factor divided out.
        ll term = prod * quickpow(x - i, mod - 2) % mod;
        // Denominator magnitude (i-1)! (lim-i)!, applied via inverse factorials.
        term = term * inv[i - 1] % mod * inv[lim - i] % mod * y[i] % mod;
        // Sign (-1)^(lim-i) of the denominator.
        ret = addMod(ret, ((lim - i) & 1) ? mod - term : term);
    }
    return ret;
}
// Reads n and k and prints, modulo 998244353, the sum over all permutations of
// length n of (number of inversions)^k, with 0^0 = 1.
//
// Let F_i(t) = prod_{m=1..i} (1 + t + ... + t^{m-1}) be the inversion generating
// function of length-i permutations. With t = 1 + u, each factor becomes
// ((1+u)^m - 1) / u = sum_j C(m, j+1) u^j. Then [u^j] F_i(1+u) is the sum over
// permutations of C(inv, j). Since x^k = sum_j S(k, j) j! C(x, j), where S is
// the Stirling numbers of the second kind, the answer for length i is
//   ans(i) = sum_{j<=k} S(k, j) * j! * [u^j] F_i(1+u).
// ans(i) / i! is a polynomial in i of degree at most 2k. So ans is computed
// directly for i <= min(n, 2k+1) and interpolated to n when n is larger.
int main() {
    read(n), read(k);
    int lim = max(n, k), len = 1;
    fac[0] = strl[0] = f[0] = 1;
    for (int i = 1; i <= lim; ++i)
        fac[i] = fac[i - 1] * i % mod;
    // Every permutation contributes 0^0 = 1.
    if (k == 0) return printf("%lld\n", fac[n]), 0;
    inv[lim] = quickpow(fac[lim], mod - 2);
    for (int i = lim - 1; i >= 0; --i)
        inv[i] = inv[i + 1] * (i + 1) % mod;
    // strl[j] = S(k, j), built row by row: S(i, j) = S(i-1, j-1) + j * S(i-1, j).
    for (int i = 1; i <= k; ++i) {
        for (int j = i; j; --j)
            strl[j] = (strl[j - 1] + strl[j] * j) % mod;
        strl[0] = 0;
    }
    // f and each factor are kept to degree <= k, so their product has degree
    // <= 2k. A transform length >= 2k + 1 therefore never wraps around.
    while ((1 << len) < 2 * k + 1) ++len;
    lim = min(n, 2 * k + 1);
    // y[1] stays 0: the single length-1 permutation has 0 inversions and k >= 1.
    for (int i = 2; i <= lim; ++i) {
        // Factor for m = i. Only coefficients u^0..u^k can reach f[0..k], so the
        // rest are dropped.
        for (int j = 0; j < (1 << len); ++j)
            g[j] = j <= k ? comb(i, j + 1) : 0;
        ntt(g, len, 0), ntt(f, len, 0);
        for (int j = 0; j < (1 << len); ++j)
            f[j] = f[j] * g[j] % mod;
        ntt(f, len, 1);
        for (int j = 0; j <= k; ++j)
            y[i] = (y[i] + f[j] * fac[j] % mod * strl[j]) % mod;
        // Truncate f back to degree k for the next multiplication.
        for (int j = k + 1; j < (1 << len); ++j)
            f[j] = 0;
    }
    if (lim == n) return printf("%lld\n", y[n]), 0;
    // Divide by i! to get samples of the degree-<=2k polynomial, then interpolate.
    for (int i = 1; i <= lim; ++i)
        y[i] = y[i] * inv[i] % mod;
    printf("%lld\n", lagrange(lim, n) * fac[n] % mod);
    return 0;
}