快速数论变换(NTT)

引入

快速傅里叶变换(FFT)的缺点进行了优化。
在计算多项式乘法(卷积)时,FFT设计三角函数、复数等很多恶心的东西,有着最大的缺点:精度问题,而在很多题目中往往需要进行取模,要求精度很高,FFT就不行了。

于是就有了快速数论变换

原根

FFT之所以可以实现,是利用了单位复根 ω 的周期性质, ω n n = 1 , ω n k = ω n k + n
通过这个性质,可以把FFT后续所有步骤全部推导出来。

NTT由于需要取模,根据模数,我们可以重新定义一个类似于单位复根的东西,使它的幂有周期性,那就是原根

原根: ω n n 1 ( m o d   p ) ,且没有 ( k = 1 , 2 , 3... , n 1 ) ω n k 1 ( m o d   p )
对于每个质数 p ,令 g p 1 1 ( m o d   p ) ,且 g k   m o d   p 都不为1, ( 1 k p 2 )
g p 1 n 就可以作为原根 ω n ,满足FFT中单位复根的一切性质。

将FFT中所有单位复根换位原根,就可以实现NTT了。

代码

//UOJ34
#include<cstdio>
#include<algorithm>
using namespace std;
const int MAXN=400005,MOD=998244353,G=3;

int pow_mod(int a,int b)
{
    int ret=1;
    while(b)
    {
        if(b&1)
            ret=(1LL*ret*a)%MOD;
        a=(1LL*a*a)%MOD;
        b>>=1;
    }
    return ret;
}

void NTT(int A[],int n,int mode)
{
    for(int i=0,j=0;i<n;i++)
    {
        if(i<j)
            swap(A[i],A[j]);
        int k=(n>>1);
        while(k&&(k&j))
        {
            j^=k;
            k>>=1;
        }
        j^=k;
    }
    for(int i=1;i<n;i<<=1)
    {
        int w1=pow_mod(G,(MOD-1)/(i<<1)),w=1;
        if(mode==-1)
            w1=pow_mod(w1,MOD-2);
        for(int j=0;j<i;j++,w=(1LL*w*w1)%MOD)
            for(int l=j,r=l+i;l<n;l+=(i<<1),r=l+i)
            {
                int temp=(1LL*A[r]*w)%MOD;
                A[r]=(A[l]-temp+MOD)%MOD;
                A[l]=(A[l]+temp)%MOD;
            }
    }
    if(mode==-1)
    {
        int inv=pow_mod(n,MOD-2);
        for(int i=0;i<n;i++)
            A[i]=(1LL*A[i]*inv)%MOD;
    }
}

void mul(const int A[],int l1,const int B[],int l2,int C[])
{
    static int tA[MAXN],tB[MAXN];
    int len=1;
    while(len<l1+l2-1)
        len<<=1;
    for(int i=0;i<len;i++)
        tA[i]=tB[i]=0;
    for(int i=0;i<l1;i++)
        tA[i]=A[i];
    for(int i=0;i<l2;i++)
        tB[i]=B[i];
    NTT(tA,len,1);
    NTT(tB,len,1);
    for(int i=0;i<len;i++)
        tA[i]=(1LL*tA[i]*tB[i])%MOD;
    NTT(tA,len,-1);
    for(int i=0;i<len;i++)
        C[i]=tA[i];
}

int A[MAXN],B[MAXN];

int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%d",A+i);
    for(int i=0;i<=m;i++)
        scanf("%d",B+i);
    mul(A,n+1,B,m+1,A);
    for(int i=0;i<n+m;i++)
        printf("%d ",A[i]);
    printf("%d\n",A[n+m]);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/can919/article/details/79461300