Codeforce 622 F. The Sum of the k-th Powers(拉格朗日插值求k次幂之和,拉格朗日插值公式)

在这里插入图片描述

题目大意:求 i = 1 n i k \displaystyle\sum_{i = 1}^ni^k
求k次幂有多种求法,例如:
伯努利数求k次幂之和(待补)
斯特林数求k次幂之和
拉格朗日插值法求k次幂之和

这里采用拉格朗日插值法进行求解。
拉格朗日可以通过 k + 1 k + 1 个点唯一确定一个 k k 次多项式,它的公式为: f ( x ) = i = 1 n y [ i ] i j x x [ j ] x [ i ] x [ j ] f(x) = \sum_{i = 1}^ny[i] \prod_{i \neq j}\frac{x - x[j]}{x[i]-x[j]}
其中 x [ i ] , y [ i ] x[i],y[i] 对应已知的点值,对已知的点很容易通过代入验证正确性,带入 x [ i ] x[i] 将会得到 y [ i ] y[i]

这个式子在一般情况下的复杂度为 O ( n 2 ) O(n^2) ,比高斯消元的 n 3 n^3 更加优秀,在已知点的 x x 取值连续的情况下,复杂度能降低到 O ( n ) O(n) ,只要预处理阶乘逆元,以及 x x 的 k + 1 项倒阶乘: x f a c xfac
f ( x ) = i = 1 n y [ i ] x f a c f a c [ i ] f a c [ n i ] ( x i ) f(x)=\sum_{i = 1}^ny[i]*\frac{xfac}{fac[i]*fac[n - i]*(x-i)}

为什么这题可以用拉格朗日插值
当然是因为 i = 1 n i k \displaystyle\sum_{i = 1}^ni^k 是一个以n为自变量的多项式,并且是 k + 1 k + 1 次多项式
证明:
S ( n , k ) = i = 1 n i k \displaystyle S(n,k)=\sum_{i = 1}^ni^k
对这个序列两两差分可以得到: ( n + 1 ) k + 1 n k + 1 = i = 0 k + 1 C ( k + 1 , i ) n i n k + 1 = i = 0 k C ( k + 1 , i ) n i (n + 1)^{k+1} - n^{k+1}=\sum_{i = 0}^{k+1}C(k+1,i)*n^i - n^{k+1}=\sum_{i = 0}^kC(k+1,i)*n^i n k + 1 ( n 1 ) k + 1 = i = 0 k C ( k + 1 , i ) ( n 1 ) i n^{k+1} - (n-1)^{k+1}=\sum_{i = 0}^{k}C(k+1,i)*(n-1)^i . . . ... 1 k + 1 0 k + 1 = i = 0 k C ( k + 1 , i ) 0 i 1^{k+1}-0^{k+1}=\sum_{i = 0}^{k}C(k+1,i)0^i

逐项求和可以得到 ( n + 1 ) k + 1 = i = 0 k C ( k + 1 , i ) S ( n , k ) \displaystyle (n+1)^{k+1} =\sum_{i=0}^kC(k+1,i)*S(n,k) ,即得到 S ( n , k ) S(n,k) 是以 n n 为自变量的 k + 1 k + 1 次多项式

f ( x ) = x k f(x) = x^k ,可以得到一个更一般的推广结论: k k 次多项式的前 n n 项和 g ( n ) g(n) 是一个以 n n 为自变量的 k + 1 k + 1 次多项式

回到这题,前 k + 2 k + 2 项可以 k log k k \log k 暴力计算,对 n k + 2 n \leq k + 2 直接输出答案,对 n > k + 2 n > k + 2 只要插值一下,根据插值公式计算即可。


代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int maxn = 1e6 + 100;
int n,k;
int x[maxn],y[maxn];				//拉格朗日差值的计算 
int fac[maxn],ifac[maxn];			//阶乘的逆元 
inline int add(int x, int y) {
  	x += y;
  	if (x >= mod)
    	x -= mod;
  	return x;
}

inline int sub(int x, int y) {
  	x -= y;
  	if (x < 0)
    	x += mod;
  	return x;
}

inline int mul(int x, int y) {
  	return (long long) x * y % mod;
}
int fpow(int a,int b) {
	int r = 1;
	while (b) {
		if (b & 1) r = mul(r,a);
		b >>= 1;
		a = mul(a,a);
	}
	return r;
}
int main() {
	scanf("%d%d",&n,&k);
	for (int i = 1; i <= k + 2; i++) {			//暴力计算 k + 2 个点,根据这 k + 2个点就可以通过插值唯一确定 k + 1次多项式 
		x[i] = i;
		y[i] = add(y[i - 1],fpow(x[i],k));
	}
	if (n <= k + 2) {							//n <= k + 2就直接输出,否则下面的处理会出错 
		printf("%d\n",y[n]);
		return 0;
	}
	fac[0] = 1;
	for (int i = 1; i <= k + 2; i++) {			//由于k+2个点x取值连续,预处理阶乘,使复杂度降低到O(k) 
		fac[i] = mul(fac[i - 1],i);
	}
	ifac[k + 2] = fpow(fac[k + 2],mod - 2);
	for (int i = k + 1; i >= 0; i--) {
		ifac[i] = mul(ifac[i + 1],i + 1);
	}
	int tmp = 1;								//n的倒阶乘 ,同样也是为了加速 
	for (int i = 1; i <= k + 2; i++) {
		tmp = mul(tmp,(n - i) % mod);
	}
	int ans = 0;
	for (int i = 1; i <= k + 2; i++) {			//插值迭代,得到 f(n) 
		int t = k + 2 - i;
		int p = (t & 1) ? -1 : 1;
		int inv = fpow((n - i) % mod,mod - 2);
		int res = mul(mul(ifac[i - 1],ifac[t]),mul(tmp,inv));
		res = mul(mul(res,p),y[i]);
		if (res < 0) res += mod;
		ans = add(ans,res);
	}
	printf("%d\n",ans);
	return 0;
}
发布了332 篇原创文章 · 获赞 10 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_41997978/article/details/104237443