UOJ269 如何优雅地求和

版权声明:写得不好,转载请通知一声,还请注明出处,感激不尽 https://blog.csdn.net/As_A_Kid/article/details/86698729

Problem

UOJ
给定 n , p n,p f ( x ) f(x) 是一个 m m 阶函数,求
(1) Q ( f ) = k = 0 n f ( k ) ( n k ) p k ( 1 p ) n k Q(f)=\sum_{k=0}^n f(k)\binom n k p^k (1-p)^{n-k}\tag1

Solution

首先 ( n m ) = 0 \binom {n} {-m}=0

我们思考简单的二项分布即当 f ( x ) = x f(x)=x 时的推导过程:

(2) Q ( f ) = k = 0 n k n ! k ! ( n k ) ! p k ( 1 p ) n k Q(f)=\sum_{k=0}^n k \frac {n!} {k!(n-k)!} p^k (1-p)^{n-k}\tag2

(3) Q ( f ) = k = 0 n n ( n 1 ) ! ( k 1 ) ! ( n k ) ! p k ( 1 p ) n k Q(f)=\sum_{k=0}^n n\frac {(n-1)!} {(k-1)!(n-k)!} p^k (1-p)^{n-k}\tag3

(4) Q ( f ) = n p k = 0 n ( n 1 k 1 ) p k 1 ( 1 p ) n k Q(f)=np\sum_{k=0}^n \binom {n-1} {k-1} p^{k-1} (1-p)^{n-k}\tag4

(5) Q ( f ) = n p k = 0 n 1 ( n 1 k ) p k ( 1 p ) n k 1 Q(f)=np\sum_{k=0}^{n-1} \binom {n-1} {k} p^k (1-p)^{n-k-1}\tag5

(6) Q ( f ) = n p ( p + ( 1 p ) ) n 1 = n p Q(f)=np(p+(1-p))^{n-1}=np\tag6

如法炮制,当 f ( x ) = x d f(x)=x^{\underline d}

(7) Q ( f ) = n d p d k = 0 n d ( n d k ) p k ( 1 p ) n k = n d p d Q(f)=n^{\underline d}p^d\sum_{k=0}^{n-d} \binom {n-d} {k}p^{k}(1-p)^{n-k}=n^{\underline d}p^d\tag7

那么我们希望能求出 b i b_i 使得 f ( x ) = i = 0 m b i x i f(x)=\sum_{i=0}^m b_i x^{\underline i} ,这样 Q ( f ) = i = 0 m b i n i p i Q(f)=\sum_{i=0}^mb_in^{\underline i}p^i

不妨令 c i = i ! b i c_i=i!b_i ,那么
(8) f ( x ) = i = 0 m c i x i i ! = i = 0 m c i ( x i ) f(x)=\sum_{i=0}^m c_i\frac {x^{\underline i}} {i!}=\sum_{i=0}^m c_i \binom {x} {i}\tag8

(9) Δ f ( x ) = f ( x + 1 ) f ( x ) = i = 0 m c i ( ( x + 1 i ) ( x i ) ) = i = 0 m c i ( x i 1 ) \Delta f(x)=f(x+1)-f(x)=\sum_{i=0}^m c_i\biggl(\binom {x+1} {i}-\binom {x} {i}\biggr)=\sum_{i=0}^{m} c_i\binom {x} {i-1}\tag9

x = 0 x=0 则仅有一个组合数为1,即 Δ f ( 0 ) = c 1 \Delta f(0)=c_1

然后一阶差分的式子和原来的式子惊人地相似,那么我们可以继续利用差分推出 Δ k f ( 0 ) = c k \Delta^kf(0)=c_k

暴力差分的时间复杂度是 O ( m 2 ) O(m^2) ,这样就可以过了。但其实 k k 阶差分是可以继续优化的,通过找规律,我们会发现

(10) Δ k f ( 0 ) = i = 0 k ( 1 ) k i ( k i ) a i \Delta^kf(0)=\sum_{i=0}^k (-1)^{k-i}\binom k i a_i\tag{10}

NTT加速即可做到 O ( m log m ) O(m\log m)

Code

#include <cstdio>
using namespace std;
typedef long long ll;
const int maxn=20010,mod=998244353;
template <typename Tp> inline int getmin(Tp &x,Tp y){return y<x?x=y,1:0;}
template <typename Tp> inline int getmax(Tp &x,Tp y){return y>x?x=y,1:0;}
template <typename Tp> inline void read(Tp &x)
{
    x=0;int f=0;char ch=getchar();
    while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
    if(ch=='-') f=1,ch=getchar();
    while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    if(f) x=-x;
}
int n,m,p,nd=1,pd=1,ans,a[maxn],b[maxn],fac[maxn],inv[maxn];
int pls(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
int power(int x,int y)
{
	int res=1;
	for(;y;y>>=1,x=(ll)x*x%mod)
	  if(y&1)
	    res=(ll)res*x%mod;
	return res;
}
int main()
{
	read(n);read(m);read(p);
	for(int i=0;i<=m;i++) read(a[i]);
	fac[0]=1;
	for(int i=1;i<=m;i++) fac[i]=(ll)fac[i-1]*i%mod;
	inv[m]=power(fac[m],mod-2);
	for(int i=m-1;~i;i--) inv[i]=(ll)inv[i+1]*(i+1)%mod;
	for(int i=0;i<=m;i++)
	{
		b[i]=(ll)a[0]*inv[i]%mod;
		for(int j=0;j<m-i;j++) a[j]=dec(a[j+1],a[j]);
	}
	for(int i=0;i<=m;i++)
	{
		ans=pls(ans,(ll)b[i]*nd%mod*pd%mod);
		nd=(ll)nd*(n-i)%mod;pd=(ll)pd*p%mod;
	}
	printf("%d\n",ans);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/As_A_Kid/article/details/86698729