FFT / NTT 多项式乘法模板

NTT:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 2000050
#define ll long long
#define MOD 998244353
// read: parse one decimal integer from stdin using getchar() and store it in x.
// Any non-digit characters before the number are skipped; if a '-' is among
// them, the value is negated. If EOF is hit before a digit is seen, x is set
// to 0 and the function returns instead of spinning forever.
template<typename T>
inline void read(T&x)
{
    T f=1,c=0;
    int ch=getchar();              // int, not char: EOF must stay distinguishable from data
    while(ch!=EOF&&(ch<'0'||ch>'9'))
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        c=10*c+(ch-'0');
        ch=getchar();
    }
    x = f*c;
}
// fastpow: binary exponentiation, returns x^y mod MOD as a value in [0, MOD).
// The base is reduced into [0, MOD) first, so x*x and ret*x always fit in a
// 64-bit integer and negative bases give a non-negative residue.
// Precondition: y >= 0 (a negative y never reaches 0 under >>=).
ll fastpow(ll x,int y)
{
    x%=MOD;
    if(x<0)x+=MOD;
    ll ret = 1;
    while(y)
    {
        if(y&1)ret=ret*x%MOD;
        x=x*x%MOD;
        y>>=1;
    }
    return ret;
}
int n,m,mx,to[2*N],lim=1,l;
// ntt: in-place iterative number-theoretic transform of a[0..len-1] modulo MOD.
//   len must be a power of two, and to[] must already hold the bit-reversal
//   permutation for len (main() fills it before calling ntt()).
//   k == 1  : forward transform, using 3 as the primitive root.
//   k == -1 : inverse transform WITHOUT the 1/len factor; the caller multiplies
//             by len^{-1}. Evaluating at w^{-j} equals evaluating at w^{len-j},
//             so the inverse is the forward transform followed by reversing
//             a[1..len-1].
void ntt(ll *a,int len,int k)
{
    // Bit-reversal reorder so the butterflies below can work in place.
    for(int i=0;i<len;i++)
        if(i<to[i])swap(a[i],a[to[i]]);
    // i is half the length of the sub-transforms being merged in this pass.
    for(int i=1;i<len;i<<=1)
    {
        // A (2i)-th root of unity mod MOD, derived from the primitive root 3.
        ll w0 = fastpow(3,(MOD-1)/(i<<1));
        for(int j=0;j<len;j+=(i<<1))
        {
            ll w = 1;
            for(int o=0;o<i;o++,w=w*w0%MOD)
            {
                ll w1 = a[j+o],w2 = a[j+o+i]*w%MOD;
                a[j+o] = (w1+w2)%MOD;
                a[j+o+i] = ((w1-w2)%MOD+MOD)%MOD;
            }
        }
    }
    // The inverse must act on the array and length passed in, not on any global.
    if(k==-1)
        reverse(a+1,a+len);
}
ll a[2*N],b[2*N],c[2*N];
// Reads n, m, then the n+1 coefficients of A and the m+1 coefficients of B
// (lowest degree first). Prints the n+m+1 coefficients of A*B mod MOD.
int main()
{
    read(n),read(m);
    // Reduce inputs into [0, MOD) so negative or huge coefficients cannot make
    // intermediate values negative or overflow the 64-bit products.
    for(int i=0;i<=n;i++)read(a[i]),a[i]=(a[i]%MOD+MOD)%MOD;
    for(int i=0;i<=m;i++)read(b[i]),b[i]=(b[i]%MOD+MOD)%MOD;
    // The product has n+m+1 coefficients and the transform is cyclic, so the
    // length must be at least n+m+1; otherwise the top terms wrap onto the low ones.
    while(lim<=n+m)lim<<=1,l++;
    // Bit-reversal permutation for length lim = 2^l.
    for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1)));
    ntt(a,lim,1),ntt(b,lim,1);
    for(int i=0;i<lim;i++)c[i]=a[i]*b[i]%MOD;
    ntt(c,lim,-1);
    // ntt(...,-1) does not scale; divide by lim via its modular inverse (MOD is prime).
    ll inv = fastpow(lim,MOD-2);
    for(int i=0;i<lim;i++)c[i]=c[i]*inv%MOD;
    for(int i=0;i<=n+m;i++)printf("%lld ",c[i]);
    puts("");
    return 0;
}

FFT:

#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Program-wide constants for the FFT version.
const int N = 2000050;               // base capacity; the global arrays are declared with size 2*N
typedef long long ll;                // used when printing the rounded coefficients
const double Pi = acos(-1.0);        // pi to double precision
// read: parse one decimal integer from stdin using getchar() and store it in x
// (T may also be a floating type, e.g. the double coefficients of the FFT).
// Any non-digit characters before the number are skipped; if a '-' is among
// them, the value is negated. If EOF is hit before a digit is seen, x is set
// to 0 and the function returns instead of spinning forever.
template<typename T>
inline void read(T&x)
{
    T f=1,c=0;
    int ch=getchar();              // int, not char: EOF must stay distinguishable from data
    while(ch!=EOF&&(ch<'0'||ch>'9'))
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        c=10*c+(ch-'0');
        ch=getchar();
    }
    x = f*c;
}
// cp: minimal complex number for the FFT (x = real part, y = imaginary part).
// The default constructor leaves x and y uninitialized. Objects with static
// storage duration (the global arrays) are zero-initialized before it runs.
struct cp
{
    double x,y;
    cp(){}
    cp(double x,double y):x(x),y(y){}
};
// The operators take const references, so they accept temporaries as well as
// lvalues (e.g. a*b*c or (a+b)*w compile).
cp operator + (const cp &a,const cp &b)
{
    return cp(a.x+b.x,a.y+b.y);
}
cp operator - (const cp &a,const cp &b)
{
    return cp(a.x-b.x,a.y-b.y);
}
// (a.x + i a.y)(b.x + i b.y) = (a.x b.x - a.y b.y) + i (a.x b.y + a.y b.x)
cp operator * (const cp &a,const cp &b)
{
    return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
int n,m,mx,to[2*N],lim=1,l;
// fft: in-place iterative radix-2 FFT of a[0..len-1].
//   len must be a power of two, and to[] must already hold the matching
//   bit-reversal permutation (filled in main()).
//   k == 1 uses roots e^{+i*pi/half}; k == -1 uses the conjugate roots, giving
//   the inverse transform WITHOUT the 1/len factor (main() divides by lim).
void fft(cp *a,int len,int k)
{
    // Reorder into bit-reversed positions so the butterflies can run in place.
    for(int pos=0;pos<len;pos++)
    {
        if(pos<to[pos])
            swap(a[pos],a[to[pos]]);
    }
    // half = size of each already-transformed half being merged in this pass.
    for(int half=1;half<len;half<<=1)
    {
        const double ang=Pi/half;
        cp step(cos(ang),k*sin(ang));    // primitive (2*half)-th root of unity
        for(int base=0;base<len;base+=2*half)
        {
            cp tw(1,0);                  // running twiddle factor step^o
            for(int o=0;o<half;o++)
            {
                cp lo=a[base+o];
                cp hi=a[base+half+o]*tw;
                a[base+o]=lo+hi;
                a[base+half+o]=lo-hi;
                tw=tw*step;
            }
        }
    }
}
cp a[2*N],b[2*N],c[2*N];
int main()
{
    read(n),read(m);mx = max(n,m);
    for(int i=0;i<=n;i++)read(a[i].x);
    for(int i=0;i<=m;i++)read(b[i].x);
    while(lim<2*mx)lim<<=1,l++;
    for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1)));
    fft(a,lim,1),fft(b,lim,1);
    for(int i=0;i<lim;i++)c[i]=a[i]*b[i];
    fft(c,lim,-1);
    for(int i=0;i<=n+m;i++)
        printf("%lld ",(ll)(c[i].x/lim+0.5));
    puts("");
    return 0;
}

猜你喜欢

转载自 https://www.cnblogs.com/LiGuanlin1124/p/10258662.html