Divide by Zero 2018 and Codeforces Round #474 G. Bandit Blues DP+斯特林数+分治FFT

版权声明:xgc原创文章,未经允许不得转载。 https://blog.csdn.net/xgc_woker/article/details/82926986

Description
给你三个正整数 n,a,b,定义A为一个排列中是前缀最大值的数的个数,定义B为一个排列中是后缀最大值的数的个数,求长度为n的排列中满足A = a且B = b的排列个数。


Sample Input
1 1 1


Sample Output
1


考虑DP,设f[i][j]为前i位有j个不同前缀最大值方案数。
我们从大到小插数,对于当前这个数他只有放在第一位才可能有新的前缀最大值,可得转移:
f [ i ] [ j ] = f [ i 1 ] [ j 1 ] + ( i 1 ) f [ i 1 ] [ j ] f[i][j]=f[i-1][j-1]+(i-1)f[i-1][j]
这个玩意是第一类斯特林数。。。
然后我们考虑以n为分界点,其实就相当于有a+b-2个前缀最大值,然后你选a-1个数放到左边,b-1个数放到右边,可得答案为:
f [ n 1 ] [ a + b 2 ] C ( a + b 2 , a 1 ) f[n-1][a+b-2]*C(a+b-2,a-1)
对于第一类斯特林数的求法,
f[n][m]就等于x的n次上升幂的第m项系数。
用分治FFT即可解决(涨姿势)


#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;
typedef long long LL;
const LL mod = 998244353;
int _min(int x, int y) {return x < y ? x : y;}
int _max(int x, int y) {return x > y ? x : y;}
int read() {
	int s = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
	return s * f;
}

int R[410000];
LL A[410000], jc[410000];

LL pow_mod(LL a, LL k) {
	LL ans = 1;
	while(k) {
		if(k & 1) (ans *= a) %= mod;
		(a *= a) %= mod; k /= 2;
	} return ans;
}

void NTT(LL y[], int len, int on) {
	for(int i = 0; i < len; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) * (len >> 1));
	for(int i = 0; i < len; i++) if(i < R[i]) swap(y[i], y[R[i]]);
	for(int i = 1; i < len; i *= 2) {
		LL wn = pow_mod(3, (LL)(mod - 1) / (i * 2)); if(on == -1) wn = pow_mod(wn, mod - 2);
		for(int j = 0; j < len; j += i * 2) {
			LL w = 1;
			for(int k = 0; k < i; k++) {
				LL u = y[j + k], v = y[j + k + i] * w % mod;
				y[j + k] = (u + v) % mod, y[j + k + i] = (u - v + mod) % mod;
				w = w * wn % mod;
			}
		}
	} if(on == -1) {
		LL tmp = pow_mod(len, mod - 2);
		for(int i = 0; i < len; i++) y[i] = y[i] * tmp % mod;
	}
}

void solve(LL *a, LL ln, int l, int r) {
	if(l == r) {a[0] = l - 1, a[1] = 1; return ;}
	int mid = (l + r) / 2;
	LL g1[ln + 10], g2[ln + 10];
	memset(g1, 0, sizeof(g1)), memset(g2, 0, sizeof(g2));
	solve(g1, ln / 2, l, mid), solve(g2, ln / 2, mid + 1, r);
	NTT(g1, ln, 1), NTT(g2, ln, 1);
	for(int i = 0; i < ln; i++) a[i] = g1[i] * g2[i] % mod;
	NTT(a, ln, -1);
}

LL C(int n, int m) {
	LL ans = 1;
	jc[0] = 1;
	for(int i = 1; i <= n; i++) jc[i] = jc[i - 1] * i % mod;
	return jc[n] * pow_mod(jc[m], mod - 2) % mod * pow_mod(jc[n - m], mod - 2) % mod;
}

int main() {
	int n = read(), a = read(), b = read();
	if(a + b - 1 > n || !a || !b) {puts("0"); return 0;}
	if(n == 1) {puts("1"); return 0;}
	LL c = C(a + b - 2, a - 1);
	int ln;
	for(ln = 1; ln <= 2 * (n + 1); ln *= 2);
	solve(A, ln, 1, n - 1);
	printf("%lld\n", A[a + b - 2] * c % mod);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/xgc_woker/article/details/82926986
今日推荐