FFT(快速傅里叶变换)与 NTT(快速数论变换)模板

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011056504/article/details/79048108

FFT

模板题:51nod 1028 大数乘法

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define db double
#define N 401000
using namespace std;
// Minimal complex number used by the FFT: x is the real part, y the imaginary part.
struct node{
    db x,y;
    // Component-wise sum.
    friend node operator +(node lhs,node rhs)
    {
        node res={lhs.x+rhs.x,lhs.y+rhs.y};
        return res;
    }
    // Component-wise difference.
    friend node operator -(node lhs,node rhs)
    {
        node res={lhs.x-rhs.x,lhs.y-rhs.y};
        return res;
    }
    // Complex product: (a+bi)(c+di) = (ac-bd) + (ad+bc)i.
    friend node operator *(node lhs,node rhs)
    {
        node res={lhs.x*rhs.x-lhs.y*rhs.y,lhs.x*rhs.y+lhs.y*rhs.x};
        return res;
    }
};
int len;            // transform length, a power of two chosen in main()
int lg;             // log2(len): bit count used by the bit-reversal in DFT()
int n;              // highest digit index of the first operand
int m;              // highest digit index of the operand last read by init()
long long ans[N];   // product digits, least significant first (after carrying)
db c[N];            // real parts of the inverse transform (convolution result)
db pi;              // pi, set in main() via acos(-1)
node a[N];          // first operand as complex coefficients
node b[N];          // second operand as complex coefficients (init() reads into it)
node q[N],w[N];     // scratch copies of the operands used by FFT()
node d[N];          // bit-reversed working buffer used by DFT()
// Reads one non-negative decimal integer from stdin into b[0..m], most
// significant digit first (b[i].x receives the digit value; b[i].y is not
// touched). On return m is the index of the last digit read, or -1 if no
// digit was found before end of input.
// Leading non-digit characters (spaces, '\n', the '\r' of a CRLF line ending)
// are skipped, so the second operand is still found when the previous line
// ends with extra characters.
void init()
{
    int c=getchar();   // int, not char, so EOF is represented reliably
    m=-1;
    while(c!=EOF&&(c<'0'||c>'9')) c=getchar();
    for(;c>='0'&&c<='9';c=getchar()) b[++m].x=c-'0';
}
// In-place iterative radix-2 transform of a[0..len-1].
//   sig =  1: forward transform.
//   sig = -1: inverse transform, including the 1/len normalisation of both
//             the real and the imaginary part.
// Relies on the globals len (a power of two), lg = log2(len) and pi, and uses
// the global d[] as a scratch buffer.
void DFT(node *a,int sig)
{
    // Bit-reversal permutation: element i moves to the index whose lg-bit
    // binary representation is that of i reversed.
    fo(i,0,len-1)
    {
        int p=0;
        for(int j=0,tp=i;j<lg;j++,tp/=2) p=(p<<1)+(tp%2);
        d[p]=a[i];
    }
    // Butterfly passes: combine pairs of length-half transforms into length-m ones.
    for(int m=2;m<=len;m*=2)
    {
        int half=m/2;
        fo(i,0,half-1)
        {
            // Twiddle factor cos(theta) + i*sin(theta), theta = sig*pi*i/half,
            // computed directly for each i (no accumulated rounding error).
            node w;
            w.x=cos(i*sig*pi/half),w.y=sin(i*sig*pi/half);
            for(int j=i;j<len;j+=m)
            {
                node u=d[j],v=d[j+half]*w;
                d[j]=u+v;
                d[j+half]=u-v;
            }
        }
    }
    // The inverse transform scales the whole complex value by 1/len.
    if(sig==-1) fo(i,0,len-1) d[i].x/=len,d[i].y/=len;
    fo(i,0,len-1) a[i]=d[i];
}
// Cyclic convolution of a[0..len-1] and b[0..len-1] through the FFT. The real
// parts of the result go to c[0..len-1]. a and b are not modified: the
// transforms run on the global scratch copies q and w.
void FFT(node *a,node *b,db *c)
{
    for(int k=0;k<len;k++)
    {
        q[k]=a[k];
        w[k]=b[k];
    }
    DFT(q,1);
    DFT(w,1);
    // Convolution becomes a pointwise product in the frequency domain.
    for(int k=0;k<len;k++) q[k]=q[k]*w[k];
    DFT(q,-1);
    for(int k=0;k<len;k++) c[k]=q[k].x;
}
// Multiplies two non-negative decimal integers (read one after the other by
// init()) using the FFT and prints their product.
// Input/output are redirected to fft.in / fft.out, presumably for local
// testing; an online judge reading stdin/stdout would need these removed.
int main()
{
    freopen("fft.in","r",stdin);
    freopen("fft.out","w",stdout);
    init();
    n=m;fo(i,0,n) a[i]=b[i];   // first operand: digits a[0..n]
    init();                    // second operand: digits b[0..m]
    // The product has n+m+1 coefficients (indices 0..n+m). A cyclic
    // convolution of length len only represents them all when len > n+m;
    // with len == n+m the top coefficient would wrap onto index 0.
    // Assumes the resulting len stays below N (ans[len] is written below).
    for(len=1;len<=n+m;len*=2);
    lg=log2(len);pi=acos(-1);
    // Reverse to least-significant-first so that index = power of ten.
    fo(i,0,n/2) swap(a[i],a[n-i]);
    fo(i,0,m/2) swap(b[i],b[m-i]);
    // b still holds digits of the first operand beyond index m; clear them.
    fo(i,m+1,n) b[i].x=b[i].y=0;
    FFT(a,b,c);
    n=len;
    fo(i,0,n-1) ans[i]=round(c[i]);
    // Carry propagation; ans[len] receives the final carry.
    fo(i,0,n-1) ans[i+1]+=ans[i]/10,ans[i]%=10;
    // Strip leading zeros but keep at least one digit, so a zero product
    // prints "0" instead of reading below ans[0].
    for(;n>0&&ans[n]==0;n--);
    for(int i=n;i>=0;i--) printf("%lld",ans[i]);
}

NTT

模板题:51nod 1348 乘积之和

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define N 70100
#define ll long long
#define db double
#define mod 100003
#define mo1 998244353ll
#define mo2 1004535809ll
using namespace std;
int n;              // number of input values
int ac;             // number of queries
int len;            // current NTT length (power of two), set by calc()
int v[N];           // input values v[1..n]
ll c[20][2][N];     // c[z][0] / c[z][1]: left / right child polynomials at recursion depth z
ll q[N],w[N];       // scratch copies of the operands used by NTT()
ll d[N];            // bit-reversed working buffer used by DFT()
ll b[20][N];        // b[z]: product polynomial produced at recursion depth z
ll a[2][N];         // convolution results modulo mo1 (a[0]) and mo2 (a[1])
ll W[N];            // W[0..len]: powers of the current root of unity, filled by NTT()
ll M=mo1*mo2;       // CRT modulus
ll ni1,ni2;         // CRT basis set in main(): ni1 = 1 mod mo1, 0 mod mo2; ni2 the reverse
db lg;              // log2(len), kept as a double
// Modular exponentiation by repeated squaring: returns a^b mod mo.
// a is reduced modulo mo first; b is expected to be non-negative.
ll mi(ll a,ll b,ll mo)
{
    ll base=a%mo;
    ll res=1;
    while(b>0)
    {
        if(b&1) res=res*base%mo;
        base=base*base%mo;
        b>>=1;
    }
    return res;
}
// Returns a*b mod M (M = mo1*mo2, about 1e18) by double-and-add, so no
// intermediate needs a 128-bit product. Intermediate sums stay below 2^63
// only when 0 <= a < M, which holds for the CRT operands passed by calc().
ll mul(ll a,ll b)
{
    ll acc=0;
    while(b)
    {
        if(b%2==1) acc=(acc+a)%M;
        a=(a+a)%M;
        b/=2;
    }
    return acc;
}
// In-place iterative number-theoretic transform of a[0..len-1] modulo mo.
// sig == 1 uses the roots W[k]; any other sig uses the inverse roots W[len-k].
// No 1/len scaling is applied here (NTT() does that after the inverse).
// Relies on the globals len, lg and W, and uses d[] as scratch.
void DFT(ll *a,int sig,ll mo)
{
    // Bit-reversal permutation into the scratch buffer d.
    for(int i=0;i<len;i++)
    {
        int rev=0,rest=i;
        for(int bit=0;bit<lg;bit++)
        {
            rev=rev*2+rest%2;
            rest/=2;
        }
        d[rev]=a[i];
    }
    // Butterfly passes over blocks of size step.
    for(int step=2;step<=len;step*=2)
    {
        const int half=step/2;
        const int stride=len/step;   // W-index increment for this pass
        for(int k=0;k<half;k++)
        {
            const ll root=(sig==1)?W[k*stride]:W[len-k*stride];
            for(int j=k;j<len;j+=step)
            {
                const ll lo=d[j];
                const ll hi=d[j+half]*root%mo;
                d[j]=(lo+hi)%mo;
                d[j+half]=(lo-hi+mo)%mo;
            }
        }
    }
    for(int i=0;i<len;i++) a[i]=d[i];
}
// Cyclic convolution of a[0..len-1] and b[0..len-1] modulo the prime mo,
// written to c[0..len-1]. First fills W[0..len] with powers of
// 3^((mo-1)/len) (3 is assumed to be a primitive root of mo); a and b are
// left unchanged because the transforms run on the scratch copies q and w.
void NTT(ll *a,ll *b,ll *c,ll mo)
{
    const ll root=mi(3,(mo-1)/len,mo);
    W[0]=1;
    W[1]=root;
    for(int k=2;k<=len;k++) W[k]=W[k-1]*root%mo;
    for(int k=0;k<len;k++)
    {
        q[k]=a[k];
        w[k]=b[k];
    }
    DFT(q,1,mo);
    DFT(w,1,mo);
    for(int k=0;k<len;k++) q[k]=q[k]*w[k]%mo;
    DFT(q,-1,mo);
    // Scale by len^{-1} (Fermat inverse, mo prime) to finish the inverse transform.
    const ll scale=mi(len,mo-2,mo);
    for(int k=0;k<len;k++) c[k]=q[k]*scale%mo;
}
// Divide and conquer: stores in b[z][0..r-l+1] the coefficients, modulo `mod`,
// of the polynomial  prod_{i=l..r} (1 + v[i]*t),
// i.e. b[z][k] is the sum of the products of all k-element subsets of v[l..r].
// Depth z owns c[z] and b[z], so the recursion depth must stay below 20.
// Each product is computed modulo the NTT primes mo1 and mo2 and recombined
// with CRT. Given inputs already reduced below `mod`, every true coefficient
// is far below mo1*mo2, so the recombination is exact.
void calc(int l,int r,int z)
{
    if(l==r)
    {
        b[z][0]=1;
        b[z][1]=v[l];
        return;
    }
    int m=(l+r)/2;
    // Left half has degree m-l+1. Copy it out before the right recursion
    // overwrites b[z+1].
    calc(l,m,z+1);
    fo(i,0,m-l+1) c[z][0][i]=b[z+1][i];
    // Right half has degree r-m.
    calc(m+1,r,z+1);
    fo(i,0,r-m) c[z][1][i]=b[z+1][i];
    // The transform length must exceed the product degree r-l+1.
    for(len=1;len<r-l+3;len*=2);
    lg=log2(len);
    // Zero-pad both operands up to the transform length.
    fo(i,m-l+2,len-1) c[z][0][i]=0;
    fo(i,r-m+1,len-1) c[z][1][i]=0;

    NTT(c[z][0],c[z][1],a[0],mo1);

    NTT(c[z][0],c[z][1],a[1],mo2);

    // CRT-combine only the r-l+2 meaningful coefficients. NTT() writes just
    // a[*][0..len-1], so index len was never part of this product.
    fo(i,0,r-l+1)
    {
        b[z][i]=((mul(a[0][i],ni1)+mul(a[1][i],ni2))%M)%mod;
    }
}
// Reads n values and ac queries. For each query x, prints the coefficient of
// t^x in prod_{i=1..n} (1 + v[i]*t) modulo `mod`, i.e. the sum of the products
// of all x-element subsets of the values.
int main()
{
    scanf("%d%d",&n,&ac);
    // Reduce into [0, mod). A negative input would otherwise stay negative
    // after % and break the modular arithmetic inside the NTT.
    fo(i,1,n) scanf("%d",&v[i]),v[i]=(v[i]%mod+mod)%mod;
    // CRT basis: ni1 = 1 (mod mo1) and 0 (mod mo2); ni2 the other way round.
    ni1=mi(mo2,mo1-2,mo1)*mo2;
    ni2=mi(mo1,mo2-2,mo2)*mo1;
    if(n>=1) calc(1,n,0);
    else b[0][0]=1;            // empty product: only the constant term 1
    for(;ac;ac--)
    {
        int x;
        if(scanf("%d",&x)!=1) break;
        // Coefficients outside 0..n are zero. Checking the range also keeps
        // the index inside b[0].
        printf("%lld\n",(x<0||x>n)?0ll:b[0][x]%mod);
    }
}

猜你喜欢

转载自blog.csdn.net/u011056504/article/details/79048108
今日推荐