题目大意:
求有多少个n的排列:
满足:
1.有a个数比它左边的都大
2.有b个数比它右边的都大
1 <= n <= 10^5
题解:
设 表示i的排列,有j个数比它左边的都大的方案数。
考虑加一个最小的数进来,转移为
发现这就是无符号第一类斯特兰数。
怎么同时满足条件a,b呢?
枚举最大的数的位置,显然a个数都在这个数的左边,b个数都在这个数的右边。
这个瞎合并成:
分治NTT去求 。
Code:
#include<cstdio>
#include<algorithm>
#define ll long long
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
#define ff(i, x, y) for(int i = x; i < y; i ++)
using namespace std;
const int mo = 998244353;
int n, A, B;
ll ksm(ll x, ll y) {
ll s = 1;
for(; y; y /= 2, x = x * x % mo)
if(y & 1) s = s * x % mo;
return s;
}
ll fac(int n) {
ll s = 1; fo(i, 1, n) s = s * i % mo;
return s;
}
ll C(int n, int m) {
return fac(n) * ksm(fac(m) * fac(n - m) % mo, mo - 2) % mo;
}
const int N = 8e5 + 5;
ll d[N]; int p[N], q[N];
ll w[N], a[N], b[N]; int tp;
void dft(ll *a, int n) {
ff(i, 0, n) {
int p = i, q = 0;
fo(j, 1, tp) q = q * 2 + p % 2, p /= 2;
if(i > q) swap(a[i], a[q]);
}
for(int m = 2; m <= n; m *= 2) {
int h = m / 2;
ff(i, 0, h) {
ll W = w[i * (n / m)];
for(int j = i; j < n; j += m) {
int k = j + h;
ll u = a[j], v = a[k] * W % mo;
a[j] = (u + v) % mo; a[k] = (u - v + mo) % mo;
}
}
}
}
void fft(ll *a, ll *b, int n) {
ll v = ksm(3, (mo - 1) / n);
w[0] = 1; ff(i, 1, n) w[i] = w[i - 1] * v % mo ;
dft(a, n); dft(b, n);
ff(i, 0, n) a[i] = a[i] * b[i] % mo;
ff(i, 1, n / 2) swap(w[i], w[n - i]);
dft(a, n); v = ksm(n, mo - 2);
ff(i, 0, n) a[i] = a[i] * v % mo;
}
void dg(int x, int y) {
if(x == y) return;
int m = x + y >> 1;
dg(x, m); dg(m + 1, y);
int n0 = q[x] - p[x] + q[m + 1] - p[m + 1];
tp = 0; while(1 << ++ tp <= n0); n0 = 1 << tp;
ff(i, 0, n0) a[i] = b[i] = 0;
fo(i, p[x], q[x]) a[i - p[x]] = d[i];
fo(i, p[m + 1], q[m + 1]) b[i - p[m + 1]] = d[i];
fft(a, b, n0);
q[x] += q[m + 1] - p[m + 1];
fo(i, p[x], q[x]) d[i] = a[i - p[x]];
}
ll solve(int n, int m) {
if(n < m) return 0;
if(n == 0 && m == 0) return 1;
fo(i, 0, n - 1) {
p[i] = ++ tp; q[i] = ++ tp;
d[p[i]] = i; d[q[i]] = 1;
}
dg(0, n - 1);
return d[p[0] + m];
}
int main() {
scanf("%d %d %d", &n, &A, &B);
if(A == 0 || B == 0) {
printf("0\n"); return 0;
}
printf("%I64d", solve(n - 1, A + B - 2) * C(A + B - 2, A - 1) % mo);
}