(早上好,笔记在注释里。)
多项式卷积模板:
FFT:
#include<iostream> #include<cstdio> #include<cmath> using namespace std; const int N=3e6+10; double pi=acos(-1); int n,m; struct node{ double x,y; node(double a=0,double b=0){ x=a,y=b; } node operator + (node const &u) const{ return node(x+u.x,y+u.y); } node operator - (node const &u) const{ return node(x-u.x,y-u.y); } node operator * (node const &u) const{ return node(x*u.x-y*u.y,y*u.x+x*u.y); } }f[N]; int pos[N]; void fft(node *f,bool flag){ for(int i=0;i<n;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=n;p<<=1){ int len=p>>1; node fir(cos(2*pi/p),sin(2*pi/p)); if(!flag)fir.y*=-1; for(int k=0;k<n;k+=p){ node buf(1,0); for(int l=k;l<k+len;l++){ node tt=buf*f[len+l]; f[len+l]=f[l]-tt; f[l]=f[l]+tt; buf=buf*fir; } } } } int main() { scanf("%d%d",&n,&m); for(int i=0;i<=n;i++)scanf("%lf",&f[i].x); for(int i=0;i<=m;i++)scanf("%lf",&f[i].y); for(m+=n,n=1;n<=m;n<<=1); for(int i=0;i<n;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?n>>1:0); } fft(f,1); for(int i=0;i<n;i++)f[i]=f[i]*f[i]; fft(f,0); for(int i=0;i<=m;i++)printf("%d ",(int)(f[i].y/n/2+0.5)); return 0; } //FFT比较丢精度,如果需要卷积的多项式系数值域相差太大,就会卡精度 //三次变两次优化涉及的精度跨度上限更大,严重掉精度
NTT:
#include<iostream> #include<cstdio> using namespace std; const int N=3e6+10,mod=998244353,G=3; int n,m,pos[N]; long long f[N],g[N],invn,invG; long long pw(long long x,long long k){ long long num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void ntt(long long *f,bool flag){ for(int i=0;i<n;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=n;p<<=1){ int len=p>>1; long long fir=pw((flag?G:invG),(mod-1)/p); for(int i=0;i<n;i+=p){ long long bur=1; for(int l=i;l<i+len;l++){ long long tt=bur*f[l+len]%mod; f[l+len]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } int main() { scanf("%d%d",&n,&m); for(int i=0;i<=n;i++)scanf("%lld",&f[i]); for(int i=0;i<=m;i++)scanf("%lld",&g[i]); for(m+=n,n=1;n<=m;n<<=1); for(int i=0;i<n;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0); } invn=pw(n,mod-2),invG=pw(G,mod-2); ntt(f,1),ntt(g,1); for(int i=0;i<n;i++){ f[i]=f[i]*g[i]%mod; } ntt(f,0); for(int i=0;i<=m;i++){ printf("%lld ",f[i]*invn%mod); } return 0; } //数组需要开到2的幂次以上,不是两倍 //第一个单位根从1开始 //mod的大小开在最大系数以上防止模掉
多项式求逆:
#include<iostream> #include<cstdio> #define ll long long using namespace std; const int N=3e5+10,mod=998244353,G=3; int n,pos[N]; ll a[N],b[N],c[N],invG; ll pw(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void ntt(ll *f,int n,bool flag){ for(int i=0;i<n;i++)if(i<pos[i])swap(f[i],f[pos[i]]); for(int p=2;p<=n;p<<=1){ int len=p>>1; ll fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<n;k+=p){ ll bur=1; for(int l=k;l<len+k;l++){ ll tt=f[l+len]*bur%mod; f[len+l]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } void getinv(int now,ll *a,ll*b){ if(now==1){b[0]=pw(a[0],mod-2);return;} getinv((now+1)>>1,a,b); int goal=1; while(goal<(now<<1))goal<<=1; ll invn=pw(goal,mod-2); for(int i=0;i<goal;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(goal>>1):0); for(int i=0;i<now;i++)c[i]=a[i]; for(int i=now;i<goal;i++)c[i]=0; ntt(c,goal,1),ntt(b,goal,1); for(int i=0;i<goal;i++)b[i]=((2ll-c[i]*b[i]%mod)%mod+mod)%mod*b[i]%mod; ntt(b,goal,0); for(int i=0;i<now;i++)b[i]=b[i]*invn%mod; for(int i=now;i<goal;i++)b[i]=0; } int main(){ scanf("%d",&n); for(int i=0;i<n;i++)scanf("%lld",&a[i]); invG=pw(G,mod-2); getinv(n,a,b); for(int i=0;i<n;i++)printf("%lld ",b[i]); return 0; } //递归每一层本质对x^now取模,逆元数组b每次处理完要把被模掉的多余部分清空。
一些题目:
持续补完,请自己记得复习。