P4463 [集训队互测2012] calc(dp + 拉格朗日插值优化)

在这里插入图片描述


考虑 dp 求解,dp 构造升序序列,最后乘上 n ! n!

f [ i ] [ j ] f[i][j] 表示前 i i 个位置,最大值小于等于 j j 的贡献,转移方程: f [ i ] [ j ] = f [ i 1 ] [ j 1 ] j + f [ i ] [ j 1 ] f[i][j] = f[i - 1][j - 1] * j + f[i][j - 1]

最终答案是 f [ n ] [ k ] f[n][k] ,k 非常大肯定无法求解,考虑优化:
g [ i ] [ j ] g[i][j] 表示前 i i 个位置,第 i i 个放 j j 的贡献,转移方程: g [ i ] [ j ] = j k = 1 j 1 g [ i 1 ] [ k ] \displaystyle g[i][j] = j*\sum_{k = 1}^{j - 1}g[i - 1][k]
显然有 f [ n ] [ k ] = i = 1 k g [ n ] [ k ] \displaystyle f[n][k] = \sum_{i = 1}^kg[n][k]

如果可以证明 g [ n ] [ k ] g[n][k] 是一个以 k k 为自变量的多项式,就可以使用拉格朗日插值快速求解

使用归纳法证明:
当 n = 0 时,g[0][k] = 0,结论成立
设 n > 0 且,g[n][k] 是一个以 k 为自变量的多项式
根据转移方程,有: g [ n + 1 ] [ k ] = k i = 1 k 1 g [ n ] [ i ] \displaystyle g[n + 1][k] =k*\sum_{i = 1}^{k - 1}g[n][i] ,根据k次幂和的推论,可以得知 g [ n + 1 ] [ k ] g[n + 1][k] 是以 k k 为自变量的多项式,当 n = 0 n = 0 d p [ n ] [ k ] dp[n][k] 是一个 0 0 次多项式, n n 每增一,根据转移方程可以得出 多项式的次数 + 2 + 2 ,因此 d p [ n ] [ k ] dp[n][k] 是一个以 k k 为自变量的 2 n 2n 次多项式。

因此 f [ n ] [ k ] f[n][k] 是一个以 k 为自变量的 2 n + 1 2n + 1 次多项式,求出 2 n + 2 2n + 2 个点,通过插值快速求出最终答案。不要忘了乘上 n ! n!


代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 4e3 + 10;
typedef long long ll;
int mod,mx,n,k;
ll fac[maxn],ifac[maxn];
inline ll add(ll x, ll y) {
  	x += y;
  	if (x >= mod) x -= mod;
  	return x;
}

inline ll sub(ll x, ll y) {
	x -= y;
	if (x < 0) x += mod;
	return x;
}

inline ll mul(ll x, ll y) {
  	return x * y % mod;
}
ll fpow(ll a,ll b) {
	ll r = 1;
	while(b) {
		if (b & 1) r = mul(r,a);
		b >>= 1;
		a = mul(a,a);
	}
	return r;
}
ll cal(ll g[maxn],ll x) {			//拉格朗日插值计算多项式
	if (x <= mx) return g[x];
	ll tmp = 1,inv,ans = 0;
	for (int i = 1; i <= mx; i++)
		tmp = mul(tmp,x - i);
	for (int i = 1; i <= mx; i++) {
		ll res = 1, inv = fpow(x - i,mod - 2);
		res = mul(res,g[i]);
		res = mul(res,ifac[i - 1]);
		res = mul(res,ifac[mx - i]);
		res = mul(res,inv);
		res = mul(res,tmp);
		if ((mx - i) & 1) res = mul(res,-1);
		if (res < 0) res += mod;
		ans = add(ans,res);
	}
	return ans;
}
ll f[maxn],tp[maxn],dp[2000][2000];
int main() {
	scanf("%d%d%d",&k,&n,&mod);
	fac[0] = 1;
	for (int i = 1; i <= 4000; i++)
		fac[i] = mul(fac[i - 1],i);
	ifac[4000] = fpow(fac[4000],mod - 2);
	for (int i = 4000 - 1; i >= 0; i--)
		ifac[i] = mul(ifac[i + 1],i + 1);
	mx = 2 * n + 4;
	for (int j = 0; j <= mx; j++)
		dp[0][j] = 1;
	for (int i = 1; i <= n; i++)
		for (int j = 1; j <= mx; j++)
			dp[i][j] = (1ll * dp[i - 1][j - 1] * j + dp[i][j - 1]) % mod;
	printf("%lld\n",1ll * cal(dp[n],k) * fac[n] % mod);
	return 0;
}
发布了332 篇原创文章 · 获赞 10 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_41997978/article/details/104302887