[Template] NTT and three-modulus NTT

I wrote a note about FFT before. FFT is a transform over the complex numbers,
and mathematicians have proved that the DFT is the only transform over the complex numbers that has the circular convolution property.

In OI, we often meet problems that ask for the answer modulo some number, which motivates us to look for a transform with the same property under modular arithmetic.
It turns out there is a magical thing, the primitive root \(g\): modulo a prime \(p\) with \(n \mid p-1\), the power \(g^{(p-1)/n}\) plays the role of the \(n\)-th root of unity \(e^{\frac{2\pi i}{n}}\).
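To make the analogy concrete, here is a minimal sketch that builds an \(n\)-th root of unity from a primitive root and lets you check its order. The helper names `power_mod` and `root_of_unity` are made up for this illustration and are not part of the code later in this note.

```cpp
#include <cassert>
#include <cstdint>

// Modular exponentiation: a^e mod m. Assumes m < 2^32 so products fit in 64 bits.
std::uint64_t power_mod(std::uint64_t a, std::uint64_t e, std::uint64_t m) {
    std::uint64_t r = 1 % m;
    a %= m;
    for (; e; e >>= 1, a = a * a % m)
        if (e & 1) r = r * a % m;
    return r;
}

// n-th root of unity modulo the prime p, built from a primitive root g.
// Hypothetical helper: requires n to divide p - 1.
std::uint64_t root_of_unity(std::uint64_t g, std::uint64_t n, std::uint64_t p) {
    assert((p - 1) % n == 0);
    return power_mod(g, (p - 1) / n, p);
}
```

For example, with \(p=17\), \(g=3\), \(n=4\) we get \(w=3^4 \bmod 17=13\), and indeed \(w^2\equiv16\equiv-1\) and \(w^4\equiv1\), just like \(i\) in the complex numbers.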

So we precompute the powers of \(g\) and their inverses, tweak the FFT code a little, and we get the number theoretic transform (NTT). I'm too lazy to explain further, so here is the code directly:

void getwn(){ // precompute powers of the primitive root and their inverses
    int x=qpow(3,p-2);
    for(int i=0;i<20;++i){
        wn[i]=qpow(3,(p-1)/(1<<i));
        inv[i]=qpow(x,(p-1)/(1<<i));
    }
}
void ntt(int *y,bool f){ rev(y); // bit-reversal permutation, same as in FFT
    for(int m=2,id=1;m<=n;m<<=1,++id){ // id records which layer we are on (m == 1<<id)
        for(int k=0;k<n;k+=m){
            int w=1,wm=f?wn[id]:inv[id]; // DFT uses the powers, IDFT uses their inverses
            for(int j=0;j<m/2;++j){
                // same butterfly as FFT, but everything is taken mod p
                int u=y[k+j]%p,t=1ll*w*y[k+j+m/2]%p;
                y[k+j]=u+t; if(y[k+j]>=p) y[k+j]-=p;
                y[k+j+m/2]=u-t; if(y[k+j+m/2]<0) y[k+j+m/2]+=p;
                w=1ll*w*wm%p;
            }
        }
    }
    if(!f){
        int x=qpow(n,p-2);
        for(int i=0;i<n;++i)
            y[i]=1ll*y[i]*x%p;
    }
}

It looks almost the same~ But it requires a modulus \(p\) whose primitive root is easy to find and for which \(p-1\) is divisible by a large power of two. Examples are the famous UOJ moduli 998244353, 1004535809 and 469762049; 3 is a primitive root of all three~
Back then, seeing a modulus other than 1e9+7 would usually make people think of NTT. To avoid giving that hint away, vfk adopted 998244353 as the modulus everywhere, and it seems to have worked well.
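What makes these moduli "NTT-friendly" can be checked directly: write \(p-1=k\cdot2^m\) with \(k\) odd, and a power-of-two NTT modulo \(p\) can have length at most \(2^m\). The following is a small sketch of that check; `two_adic_order` and `odd_part` are names invented here for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Write p - 1 = k * 2^m with k odd and return m.
// A power-of-two NTT modulo p supports lengths up to 2^m.
int two_adic_order(std::uint32_t p) {
    assert(p > 2 && p % 2 == 1);
    std::uint32_t x = p - 1;
    int m = 0;
    while (x % 2 == 0) { x /= 2; ++m; }
    return m;
}

// The odd factor k in p - 1 = k * 2^m.
std::uint32_t odd_part(std::uint32_t p) {
    return (p - 1) >> two_adic_order(p);
}
```

This gives \(998244353=119\cdot2^{23}+1\), \(1004535809=479\cdot2^{21}+1\) and \(469762049=7\cdot2^{26}+1\), while for 1e9+7 we only get \(m=1\), which is why it cannot be used as an NTT modulus directly.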

However, some crazy people insist on using 1e9+7 as the NTT modulus, and there are even problems where the modulus is not a prime!
So how do we do NTT with an arbitrary modulus? We can use coefficient splitting or three-modulus NTT. Here is an introduction to three-modulus NTT.
For the usual data range \(n\leq10^5, a_i\leq10^9\), a coefficient of the product can reach \(10^5\times(10^9)^2=10^{23}\).
So we pick three NTT-friendly primes whose product exceeds \(10^{23}\), run the NTT under each of them, and then find a way to merge the results. If the true answer is \(ans\), three NTTs give us the following three congruences:
\[ \left\{\begin{matrix} ans\equiv a_1\pmod{m_1}\\ ans\equiv a_2\pmod{m_2}\\ ans\equiv a_3\pmod{m_3} \end{matrix}\right. \]
Merging the first two congruences with the Chinese remainder theorem gives
\[ \left\{\begin{matrix} ans\equiv A\pmod{M}\\ ans\equiv a_3\pmod{m_3} \end{matrix}\right. \]
where \(M=m_1m_2\) and \(0\leq A<M\).
Now set \(ans=kM+A\); then
\[ kM+A\equiv a_3\pmod{m_3} \\ k\equiv(a_3-A)M^{-1}\pmod{m_3} \]
Since \(0\leq ans<m_1m_2m_3\), taking \(0\leq k<m_3\) pins down \(ans\) exactly. Once we have \(k\), substitute it back into \(ans=kM+A\) and reduce by the required modulus to get the result for an arbitrary modulus.
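The merging steps above can be sketched as a standalone function. This is only an illustration under some assumptions: the three moduli are the ones named earlier, the target modulus `p` is below \(2^{32}\), and the first merge is done incrementally (\(A=a_1+m_1t\)) instead of with the symmetric CRT formula, so that every intermediate value fits in 64 bits. The names `crt3`, `power_mod` and `inverse_mod` are made up here.

```cpp
#include <cassert>
#include <cstdint>

typedef unsigned long long u64;

const u64 M1 = 469762049, M2 = 998244353, M3 = 1004535809;

// a^e mod m, assuming m < 2^32 so products fit in 64 bits.
u64 power_mod(u64 a, u64 e, u64 m) {
    u64 r = 1 % m;
    a %= m;
    for (; e; e >>= 1, a = a * a % m)
        if (e & 1) r = r * a % m;
    return r;
}

// Inverse modulo a prime m via Fermat's little theorem.
u64 inverse_mod(u64 a, u64 m) { return power_mod(a % m, m - 2, m); }

// Given x mod M1, M2, M3 for some 0 <= x < M1*M2*M3, return x mod p.
u64 crt3(u64 a1, u64 a2, u64 a3, u64 p) {
    assert(p > 0 && p < (1ULL << 32));
    // Step 1: A = a1 + M1 * t with t = (a2 - a1) * M1^{-1} mod M2,
    // so A = x mod (M1*M2); A < M1*M2 ~ 4.7e17 fits in 64 bits.
    u64 t = (a2 + M2 - a1 % M2) % M2 * inverse_mod(M1, M2) % M2;
    u64 A = a1 + M1 * t;
    // Step 2: k = (a3 - A) * M^{-1} mod M3, where M = M1*M2.
    u64 M = M1 * M2;
    u64 k = (a3 + M3 - A % M3) % M3 * inverse_mod(M % M3, M3) % M3;
    // x = k*M + A; reduce each factor mod p before multiplying.
    return (k % p * (M % p) % p + A % p) % p;
}
```

Note that \(k\cdot M\) itself is never formed, since it would not fit in 64 bits; only its residue modulo `p` is computed.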

When merging with the Chinese remainder theorem, multiplying directly can overflow long long, so we need \(O(1)\) fast multiplication~
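As a rough sketch of what such a multiplication can look like, here are two variants for non-negative operands. `mul_wide` assumes the compiler provides `unsigned __int128` (a GCC/Clang extension), and `mul_ld` is the long double trick, which is only reliable for large moduli when `long double` has more precision than `double` (for example the 80-bit format on x86). Both function names are invented for this illustration.

```cpp
#include <cassert>

typedef long long LL;
typedef unsigned long long ULL;

// Exact variant: 128-bit intermediate product (compiler extension, assumed available).
LL mul_wide(LL a, LL b, LL m) {
    assert(a >= 0 && b >= 0 && m > 0);
    return (LL)((unsigned __int128)(a % m) * (ULL)(b % m) % (ULL)m);
}

// Long double trick: estimate the quotient q = floor(a*b/m) in floating point,
// then compute a*b - q*m with wrapping unsigned arithmetic. The estimate may be
// off by one, which the final correction absorbs.
LL mul_ld(LL a, LL b, LL m) {
    assert(a >= 0 && b >= 0 && m > 0);
    a %= m; b %= m;
    ULL q = (ULL)((long double)a * b / m);
    LL r = (LL)((ULL)a * (ULL)b - q * (ULL)m);
    r %= m;
    return r < 0 ? r + m : r;
}
```

If `__int128` is available, `mul_wide` is the simpler choice; the long double version is what you usually see in older OI code.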

Finally, here is the code for luogu4245 [Template] MTT.
I think my code style is a bit ugly qwq

#include <cstdio>
#include <cstring>
#include <algorithm>
typedef long long LL;
const int N=600020,p0=469762049,p1=998244353,p2=1004535809;
const LL M=1ll*p0*p1;
int wn[20],nw[20],rev[N],n,lg,p;
int qpow(int a,int b,int p,int s=1){
    for(;b;b>>=1,a=1ll*a*a%p)
        if(b&1) s=1ll*s*a%p;
    return s;
}
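LL mul(LL a,LL b,LL p){ a%=p; b%=p; // O(1) modular multiplication, needs an extended-precision long double
    typedef unsigned long long ULL; // unsigned arithmetic so that the wraparound is well-defined
    LL r=(LL)((ULL)a*b-(ULL)((long double)a*b/p)*p);
    return (r%p+p)%p;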
}
void calcw(int p){
    int x=qpow(3,p-2,p);
    for(int i=0;i<20;++i){
        wn[i]=qpow(3,(p-1)/(1<<i),p);
        nw[i]=qpow(x,(p-1)/(1<<i),p);
    }
}
void init(){
    for(int i=0;i<n;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<lg);
}
void ntt(int *y,bool f,int p){ calcw(p);
    for(int i=0;i<n;++i) if(i<rev[i]) std::swap(y[i],y[rev[i]]);
    for(int m=2,id=1;m<=n;m<<=1,++id){
        for(int k=0;k<n;k+=m){
            int w=1,wm=f?wn[id]:nw[id];
            for(int j=0;j<m>>1;++j){
                int &a=y[k+j]; int &b=y[k+j+m/2];
                int u=a%p,t=1ll*w*b%p;
                a=u+t; if(a>=p) a-=p;
                b=u-t; if(b<0) b+=p;
                w=1ll*w*wm%p;
            }
        }
    } int x=qpow(n,p-2,p);
    if(!f) for(int i=0;i<n;++i) y[i]=1ll*y[i]*x%p;
}
char c1[N],c2[N]; int a[N],b[N],c[N],d[N],ans[3][N];
int main(){
    int l1,l2; scanf("%d%d%d",&l1,&l2,&p);
    for(int i=0;i<=l1;++i) scanf("%d",&a[i]),a[i]%=p;
    for(int i=0;i<=l2;++i) scanf("%d",&b[i]),b[i]%=p;
    for(n=1;n<=l1||n<=l2;n<<=1,++lg); n<<=1; init(); // now n > l1+l2 and lg == log2(n)-1
    std::copy(a,a+n,c); std::copy(b,b+n,d);
    ntt(c,1,p0); ntt(d,1,p0);
    for(int i=0;i<n;++i) ans[0][i]=1ll*c[i]*d[i]%p0;
    std::copy(a,a+n,c); std::copy(b,b+n,d);
    ntt(c,1,p1); ntt(d,1,p1);
    for(int i=0;i<n;++i) ans[1][i]=1ll*c[i]*d[i]%p1;
    std::copy(a,a+n,c); std::copy(b,b+n,d);
    ntt(c,1,p2); ntt(d,1,p2);   
    for(int i=0;i<n;++i) ans[2][i]=1ll*c[i]*d[i]%p2;
    ntt(ans[0],0,p0); ntt(ans[1],0,p1); ntt(ans[2],0,p2);
    for(int i=0;i<n;++i){
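        // merge the results mod p0 and p1 into A mod M = p0*p1
        LL A=mul(1ll*ans[0][i]*p1%M,qpow(p1%p0,p0-2,p0),M)
            +mul(1ll*ans[1][i]*p0%M,qpow(p0%p1,p1-2,p1),M);
        if(A>=M) A-=M;
        // k = (a_3 - A) * M^{-1} mod p2, then ans = k*M + A, reduced mod p
        LL k=((ans[2][i]-A)%p2+p2)%p2*qpow(M%p2,p2-2,p2)%p2;
        a[i]=(1ll*(k%p)*(M%p)%p+A%p)%p;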
    }
    for(int i=0;i<=l1+l2;++i) printf("%d ",a[i]);
}
