[JZOJ5666]【GDOI2018Day2模拟4.18】法力风暴(分治NTT 模板)

Description

这里写图片描述
2 n 10 5 , 0 A i , k 10 9

Solution

注意到一次操作打出的伤害就是原来A的乘积减去操作后A的乘积

那么题目转化为求原来A的乘积减去最终A的乘积的期望

a [ i ] 最终被减去了 b [ i ]

那么最终期望为

E ( ( A [ i ] b [ i ] ) ) = 1 n k b [ i ] = k ( k ! b [ i ] ! × ( A [ i ] b [ i ] ) )

我们可以把 k ! n k 提出来,对每个i单独考虑
就变成
k ! n k b [ i ] = k ( ( A [ i ] b [ i ] ) b [ i ] ! )

设指数型生成函数

F i ( x ) = j 0 ( A [ i ] j ) x j j !

= j 0 A [ i ] x j j x j j ! = j 0 ( A [ i ] x j j ! + x × x j 1 ( j 1 ) ! )

注意到 j 0 x j j ! = e x

因为我们不关心整个序列是否收敛,我们关心的是具体项的系数
那么原式

= ( A [ i ] x ) e x

把所有i乘起来

e n x ( A [ i ] x )

现在要求的就是这个多项式 x k 项的系数

F ( x ) = ( A [ i ] x ) , G ( x ) = e n x ( A [ i ] x )

F ( x ) 可以用分治NTT处理,次数是n

我们要求的是 G ( x ) x k 这一项的系数
暴力卷积
G ( x ) [ x k ] = k ! n k j = 0 n n k j ( k j ) ! × F ( x ) [ x j ]

复杂度 O ( N log 2 N )

Code

#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
#define N 262144
#define LL long long
#define mo 998244353
using namespace std;
LL wi[2*N+5],a[2*N+5],ny[2*N+5],wg[2*N+5],a1[2*N+10],c[2*N+5],d[2*N+5];
int bit[2*N+5],n,m,L,fi[N],le[N],l2[N+5];
void prp(int num)
{
    fo(i,0,num-1) bit[i]=(bit[i>>1]>>1)|((i&1)<<(L-1));
    fo(i,0,num) wi[i]=wg[i*(N/num)];
}
inline LL md(LL x)
{
    return(x<0)?(x+mo):((x>=mo)?x-mo:x);
}
void NTT(LL *a,int pd,int num)
{
    prp(num);
    fo(i,0,num-1) if(i<bit[i]) swap(a[i],a[bit[i]]);
    int lim=num>>1,half=1;
    LL v;
    for(int m=2;m<=num;half=m,m<<=1,lim>>=1)
    {
        fo(i,0,half-1)
        {
            LL w=(pd==1)?wi[i*lim]:wi[num-i*lim];
            for(int j=i;j<num;j+=m)
            {
                v=(w*a[j+half])%mo;
                a[j+half]=md(a[j]-v);
                a[j]=md(a[j]+v);
            }
        }
    }
    if(pd<0) fo(i,0,num-1) a[i]=a[i]*ny[num]%mo;
}
void doit(int l,int r)
{
    if(l==r) return;
    int mid=(l+r)>>1;
    doit(l,mid),doit(mid+1,r);
    L=l2[le[l]+le[mid+1]];
    int rm=1<<L;
    fo(j,0,rm-1) 
    {
        c[j]=(j<=le[l]-1)?a1[fi[l]+j]:0;
        d[j]=(j<=le[mid+1]-1)?a1[fi[mid+1]+j]:0;
    }  
    NTT(c,1,rm),NTT(d,1,rm);
    fo(j,0,rm) c[j]=c[j]*d[j]%mo;
    NTT(c,-1,rm);
    le[l]=le[l]+le[mid+1]-1;
    fo(j,0,le[l]-1) a1[fi[l]+j]=c[j];
}
LL ksm(LL k,LL n)
{
    LL s=1;
    for(;n;k=k*k%mo,n>>=1) if(n&1) s=s*k%mo;
    return s;
}
int main()
{
    cin>>n>>m;
    LL s1=1;
    fo(i,1,n) scanf("%lld",&a[i]),a1[2*(i-1)]=a[i],a1[2*i-1]=998244352,fi[i]=2*(i-1),le[i]=2,s1=s1*a[i]%mo; 
    l2[1]=0;
    fo(i,1,18) l2[1<<i]=i;
    fod(i,N,1) if(!l2[i]) l2[i]=l2[i+1];
    LL vw=ksm(3,3808);
    ny[0]=wg[0]=1;
    fo(i,1,N) wg[i]=wg[i-1]*vw%mo,ny[i]=ksm(i,mo-2);
    doit(1,n);
    LL s=1,ans=0,yn=ksm(ksm(n,mo-2),m);
    fo(i,0,n)
    {
        if(i>m) break;
        ans=(ans+ksm(n,m-i)*s%mo*a1[i]%mo)%mo;
        s=s*(m-i)%mo;
    }
    ans=ans*yn%mo;
    printf("%lld\n",(s1-ans+mo)%mo);
}

猜你喜欢

转载自blog.csdn.net/hzj1054689699/article/details/80051184