【模板】FFT&NTT

参考博客:

https://oi.men.ci/fft-notes/

https://oi.men.ci/fft-to-ntt/

FFT:

#include<bits/stdc++.h>
using namespace std;
// Complex number type used for every FFT sample.
using D = complex<double>;

// pi, obtained from the arc cosine of -1.
const double PI = acos(-1.0);

// Capacity of each working array below.
const int N = 4000020;

// n:   transform length used by init() and FFT().
// zws: number of bits FFT() reverses when permuting indices.
int n, zws;

// n1, n2: number of coefficients of the first / second input (set in main()).
int n1, n2;

// Raw integer coefficients of the two inputs and the rounded product.
int a[N];
int b[N];
int res[N];

// c1, c2: complex working buffers for the two inputs (c1 also receives the product).
D c1[N];
D c2[N];

// omega / iomega: twiddle tables filled by init() and read by FFT().
D omega[N];
D iomega[N];
// Reads the next decimal integer from stdin into rx.
//
// Every character that is not a digit is skipped; if a '-' is seen while
// skipping, the parsed value is negated. Parsing stops at the first non-digit
// after the number (that character is consumed).
// If end-of-input is reached before any digit, rx is left at 0 and the
// function returns instead of looping forever.
inline void mread(int &rx)
{
    int sign=1;
    // getchar() returns int so EOF can be told apart from real characters;
    // narrowing it to char loses that distinction.
    int ch=getchar();
    rx=0;
    while(ch<'0'||ch>'9')
    {
        if(ch==EOF) return;
        if(ch=='-') sign=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        rx=rx*10+(ch-'0');
        ch=getchar();
    }
    rx*=sign;
}
// Fills the twiddle tables for the current transform length n:
//   omega[k]  = (cos(2*PI*k/n),  sin(2*PI*k/n))   -- selected by FFT() when flag==1
//   iomega[k] = (cos(2*PI*k/n), -sin(2*PI*k/n))   -- selected by FFT() otherwise
// Must be called after n has been set and before FFT() is used.
inline void init()
{
    for(int k=0;k<n;k++)
    {
        const double angle=2*PI/n*k;
        omega[k]=D(cos(angle),sin(angle));
        // The entry for the opposite rotation is the complex conjugate, so a
        // second cos/sin evaluation per index is unnecessary.
        iomega[k]=conj(omega[k]);
    }
}
// In-place iterative transform of w[0..n-1] using the precomputed twiddle tables.
//
// w:    buffer of n complex samples, overwritten with the result.
// flag: 1 selects the omega table; any other value selects iomega.
//
// Relies on the globals n (length) and zws (bit count for index reversal);
// main() sets them so that n == 2^zws. No scaling by 1/n is applied here.
inline void FFT(D *w,int flag)
{
    // Step 1: permute the samples into bit-reversed order over zws bits.
    for(int idx=0;idx<n;++idx)
    {
        int rev=0;
        for(int bit=0,rest=idx;bit<zws;++bit,rest>>=1) rev=(rev<<1)|(rest&1);
        // Swap each pair once only.
        if(idx>rev) std::swap(w[idx],w[rev]);
    }

    // Step 2: butterfly passes over blocks of doubling size.
    const D *roots=(flag==1)?omega:iomega;
    for(int span=2;span<=n;span<<=1)
    {
        const int half=span>>1;
        // Distance between consecutive twiddles of this pass inside the table.
        const int stride=n/span;
        for(int base=0;base<n;base+=span)
        {
            D *lo=w+base;
            D *hi=lo+half;
            for(int k=0;k<half;++k)
            {
                const D twisted=roots[stride*k]*hi[k];
                hi[k]=lo[k]-twisted;
                lo[k]+=twisted;
            }
        }
    }
}
int main()
{
    //freopen("test.in","r",stdin);
    int i,j;
    mread(n1);mread(n2);
    n1++;n2++;
    for(i=0;i<n1;i++) mread(a[i]);
    for(i=0;i<n2;i++) mread(b[i]);
    for(i=0;i<n1;i++) c1[i].real(a[i]);
    for(i=0;i<n2;i++) c2[i].real(b[i]); 
    for(n=1,zws=0;n<n1+n2;n<<=1,zws++);
    init();
    FFT(c1,1);FFT(c2,1);
    for(i=0;i<n;i++) c1[i]*=c2[i];
    FFT(c1,-1);
    for(i=0;i<n1+n2-1;i++) c1[i]/=n;
    for(i=0;i<n1+n2-1;i++) res[i]=floor(c1[i].real()+0.5); 
    for(i=0;i<n1+n2-1;i++) printf("%d ",res[i]);
    return 0;
}

NTT:

#include<bits/stdc++.h>
using namespace std;
// 64-bit integer type used throughout the NTT program.
using LL = long long;

// Capacity of the working arrays.
const LL N = 4000020;
// Modulus for all arithmetic helpers.
const LL Mod = 998244353;
// G:  base from which NTT() derives its roots when flag==1.
// Gi: modular inverse of G (G*Gi % Mod == 1), used for the other direction.
const LL G = 3;
const LL Gi = 332748118;

// n:   transform length; mws: number of bits NTT() reverses.
LL n, mws;
// n1, n2: number of coefficients of the two inputs (set in main()).
LL n1, n2;
// ny: modular inverse of n, computed in main() for the final scaling.
LL ny;

// Working buffers for the two inputs; c1 also receives the product.
LL c1[N];
LL c2[N];
// Reads the next decimal integer from stdin into rx.
//
// Every character that is not a digit is skipped; if a '-' is seen while
// skipping, the parsed value is negated. Parsing stops at the first non-digit
// after the number (that character is consumed).
// If end-of-input is reached before any digit, rx is left at 0 and the
// function returns instead of looping forever.
inline void mread(LL &rx)
{
    LL sign=1;
    // getchar() returns int so EOF can be told apart from real characters;
    // narrowing it to char loses that distinction.
    int ch=getchar();
    rx=0;
    while(ch<'0'||ch>'9')
    {
        if(ch==EOF) return;
        if(ch=='-') sign=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        rx=rx*10+(ch-'0');
        ch=getchar();
    }
    rx*=sign;
}
// Returns (a*b) % Mod. The product is formed in LL before reduction.
inline LL mmul(LL a,LL b)
{
    const LL product=a*b;
    return product%Mod;
}
// Returns (a+b) % Mod.
inline LL madd(LL a,LL b)
{
    const LL sum=a+b;
    return sum%Mod;
}
// Returns (a-b) modulo Mod, shifted into the range [0, Mod).
inline LL msub(LL a,LL b)
{
    LL diff=(a-b)%Mod;
    // % keeps the sign of the dividend, so lift negative remainders by Mod.
    if(diff<0) diff+=Mod;
    return diff;
}
// Modular exponentiation by repeated squaring: returns x^bs reduced with mmul().
// Returns 1 when bs <= 0.
inline LL mquery(LL x,LL bs)
{
    LL result=1;
    for(;bs>0;bs>>=1)
    {
        // Multiply in the current power whenever the lowest exponent bit is set.
        if(bs&1LL) result=mmul(result,x);
        x=mmul(x,x);
    }
    return result;
}
// In-place iterative number-theoretic transform of a[0..n-1] modulo Mod.
//
// a:    buffer of n values, overwritten with the result.
// flag: 1 derives the per-pass root from G; any other value derives it from Gi.
//
// Relies on the globals n (length) and mws (bit count for index reversal);
// main() sets them so that n == 2^mws. No scaling by 1/n is applied here.
inline void NTT(LL *a,LL flag)
{
    // Step 1: permute the values into bit-reversed order over mws bits.
    for(LL idx=0;idx<n;++idx)
    {
        LL rev=0;
        for(LL bit=0,rest=idx;bit<mws;++bit,rest>>=1) rev=(rev<<1)|(rest&1);
        // Swap each pair once only.
        if(idx<rev) swap(a[idx],a[rev]);
    }

    // Step 2: butterfly passes over blocks of doubling size.
    const LL base=(flag==1)?G:Gi;
    for(LL span=2;span<=n;span<<=1)
    {
        const LL half=span>>1;
        // Root for this pass: base^((Mod-1)/span).
        const LL step=mquery(base,(Mod-1)/span);
        for(LL start=0;start<n;start+=span)
        {
            LL twiddle=1;
            for(LL k=0;k<half;++k)
            {
                const LL upper=a[start+k];
                const LL der=mmul(twiddle,a[start+half+k]);
                a[start+half+k]=msub(upper,der);
                a[start+k]=madd(upper,der);
                twiddle=mmul(twiddle,step);
            }
        }
    }
}
int main()
{
    //freopen("test.in","r",stdin);
    LL i,j;
    mread(n1);mread(n2);
    n1++;n2++;
    for(i=0;i<n1;i++) mread(c1[i]);
    for(i=0;i<n2;i++) mread(c2[i]);
    for(n=1,mws=0;n<n1+n2-1;n<<=1,mws++);
    NTT(c1,1);NTT(c2,1);
    for(i=0;i<n;i++) c1[i]=mmul(c1[i],c2[i]);
    NTT(c1,-1);
    ny=mquery(n,Mod-2);
    for(i=0;i<n1+n2-1;i++)
    {
        printf("%lld ",mmul(c1[i],ny));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/spacevortex/p/10220738.html