[Polynomial Algorithms] (Part 3) MTT: FFT/NTT Under an Arbitrary Modulus — Study Notes

Other posts in this polynomial-algorithm series:

[Polynomial algorithm] (Part 1) FFT Fast Fourier Transform Learning Notes

[Polynomial algorithm] (Part 2) NTT (Number Theoretic Transform) study notes

[Polynomial algorithm] (Part 4) FWT Fast Walsh Transform Learning Notes

[Polynomial algorithm] (Part 5) Divide and Conquer FFT study notes


3. Hard: MTT

Definition

  • MTT (Maoxiao Theoretic Transforms)

    Chinese name: I don't know — the English name above is made up.

    (Or, jokingly: Most TLE Transforms)


Q: Now that we have learned FFT and NTT, what is MTT, and what is it good for?

A: It is very useful.

Suppose we need the convolution of two integer polynomials of length \(n \le 10^5\) with coefficients \(a_i, b_i \le 10^9\), and the answer must be taken modulo an arbitrary \(p \le 10^9\).

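You will then find that intermediate values can reach about \(10^{23}\) (\(10^5\) terms, each a product of two numbers up to \(10^9\)). Plain FFT loses precision at that magnitude, and plain NTT is useless because an arbitrary \(p\) generally does not have the special form NTT requires.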

You could use big-integer (high-precision) arithmetic. However, it is hard to implement and slow, and languages with built-in big integers, such as Python and Java, are banned in some contests.

This is where MTT comes in.


Analysis

There are \(2\) ways to implement MTT: three-modulus NTT, and FFT with coefficient splitting.

The NTT approach is exact but has a large constant factor. The FFT approach is the opposite: faster, but precision has to be managed.

These two algorithms are described below.


Three-modulus NTT

The main idea is to run NTT under \(3\) NTT-friendly moduli of order \(10^9\), which gives \(3\) result sequences. By the Chinese remainder theorem, the three residues determine each coefficient uniquely modulo the product of the moduli, which is about \(4.7 \times 10^{26}\). Since the true coefficients are below \(10^{23}\), every coefficient can be recovered from the \(3\) sequences.

To choose the moduli, you can write a program to search for them or use a lookup table; Miskcoo's table is recommended.

This article uses \(3\) moduli small enough that adding two residues never overflows `int`: \(469762049, 998244353, 1004535809\).

All \(3\) of them have primitive root \(3\), which is very convenient.

Suppose we finally obtain \(3\) sequences \(A, B, C\) and want to recover the \(i\)-th coefficient \(x\). The problem becomes a system of congruences:
\[\begin{cases} x \equiv A_i \pmod{p_1} \\ x \equiv B_i \pmod{p_2} \\ x \equiv C_i \pmod{p_3} \end{cases}\]
Combining them directly with the Chinese remainder theorem needs `__int128` or big-integer arithmetic, and both are inconvenient.

Instead, we can use the EXCRT (extended Chinese remainder theorem) approach and merge the congruences one at a time. Here, division means multiplying by the modular inverse:
\[\begin{aligned} A_i + k_1p_1 &= B_i + k_2p_2 \\ A_i + k_1p_1 &\equiv B_i \pmod{p_2} \\ k_1 &\equiv \frac{B_i - A_i}{p_1} \pmod{p_2} \end{aligned}\]
This gives the general solution of the first \(2\) congruences, \(X = A_i + k_1p_1\) with \(0 \le X < p_1p_2\). Next we merge it with the \(3\)rd:
\[\begin{aligned} X + k_3p_1p_2 &= C_i + k_4p_3 \\ X + k_3p_1p_2 &\equiv C_i \pmod{p_3} \\ k_3 &\equiv \frac{C_i - X}{p_1p_2} \pmod{p_3} \end{aligned}\]
This gives the general solution of all \(3\), \(x' = X + k_3p_1p_2\), and the answer is \(x' \bmod p\). It can be computed as \((X \bmod p + k_3 \cdot (p_1p_2 \bmod p)) \bmod p\), so \(x'\) never has to be formed explicitly.


Code

In summary, we need \(3\) NTT multiplications, i.e. \(9\) DFTs/IDFTs in total. The constant factor is large (very large — my implementation is not great), so pay attention to constant-factor optimization.

Example: Luogu P4245 [Template] Arbitrary-Modulus NTT

#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#define rint register int
typedef long long ll;

//Having A Daydream...

char In[1<<20],*p1=In,*p2=In,Ch;
// Getchar: next byte from a 1 MiB stdin buffer that is refilled with fread when empty;
// yields EOF once fread returns nothing.
#define Getchar (p1==p2&&(p2=(p1=In)+fread(In,1,1<<20,stdin),p1==p2)?EOF:*p1++)
// Reads the next non-negative decimal integer, skipping any non-digit bytes before it.
// Returns 0 if input ends before a digit is seen. The previous version looped forever
// on EOF, because isdigit(EOF) is false and the skip loop never tested for it.
// Bytes are passed to isdigit as unsigned char: a negative char value is UB there.
inline int Getint()
{
    int x=0,c;
    do c=Getchar;
    while(c!=EOF&&!isdigit((unsigned char)c));
    for(;c!=EOF&&isdigit((unsigned char)c);c=Getchar)x=x*10+(c^48);
    Ch=(char)c; // keep the global look-ahead character updated as before
    return x;
}

char Out[22222222],*Outp=Out,St[22],*Tp=St;
// Appends the decimal digits of x (expected to be non-negative) to the Out buffer.
// No separator is written; the caller appends one itself.
inline void Putint(int x)
{
    // Stack the digits on St, least significant first.
    for(;;)
    {
        *Tp++=x%10^48;
        x/=10;
        if(!x)break;
    }
    // Unstack them so the most significant digit is emitted first.
    while(Tp!=St)*Outp++=*--Tp;
}

// Computes a^b mod p by binary exponentiation (p >= 1, b >= 0).
// The result is always in [0, p), including when a is negative or p == 1.
// Intermediate products (p-1)^2 must fit in long long, so p must stay below about 3e9.
inline long long Pow(long long a,long long b,long long p)
{
    long long Res=1%p; // 1%p makes Pow(x,0,1) return 0 instead of 1
    a%=p;
    if(a<0)a+=p;       // normalise so the result is never negative
    for(;b;b>>=1,a=a*a%p)
        if(b&1)Res=Res*a%p;
    return Res;
}

int r[1<<18];
namespace Poly
{
    //const int p[3]={469762049,998244353,1004535809};
    #define Add(a,b) (((a)+(b))>=p?(a)+(b)-p:(a)+(b))

    void NTT(int n,int *A,int p,int g)
    {
        for(rint i=0;i<n;++i)if(i<r[i])std::swap(A[i],A[r[i]]);
        for(rint i=2,h=1;i<=n;i<<=1,h<<=1)
            for(rint j=0,Rs=Pow(g,(p-1)/i,p);j<n;j+=i)
                for(rint k=0,Rt=1;k<h;++k,Rt=(ll)Rt*Rs%p)
                {
                    int Tmp=(ll)A[j+h+k]*Rt%p;
                    A[j+h+k]=Add(A[j+k],p-Tmp),A[j+k]=Add(A[j+k],Tmp);
                }
    }

    int A[1<<18],B[1<<18];
    void Multiply(int n,int *F,int *G,int p,int *S)
    {
        memcpy(A,F,n*sizeof(int));
        memcpy(B,G,n*sizeof(int));
        NTT(n,A,p,3),NTT(n,B,p,3);
        for(rint i=0;i<n;++i)A[i]=(ll)A[i]*B[i]%p;
        NTT(n,A,p,Pow(3,p-2,p));
        int In=Pow(n,p-2,p);
        for(rint i=0;i<n;++i)S[i]=(ll)A[i]*In%p;
    }
}

int n,m,p,F[1<<18],G[1<<18],S[3][1<<18];
const int P[]={469762049,998244353,1004535809};

int main()
{
    n=Getint(),m=Getint(),p=Getint();
    for(rint i=0;i<=n;++i)F[i]=Getint();
    for(rint i=0;i<=m;++i)G[i]=Getint();
    for(m=n+m,n=1;n<=m;n<<=1);
    for(rint i=0,l=(int)log2(n);i<n;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    for(rint i=0;i<3;++i)Poly::Multiply(n,F,G,P[i],S[i]);//计算F*G mod P[i],储存在S[i]
    for(rint i=0;i<=m;++i)
    {
        ll x=S[0][i]+((S[1][i]-S[0][i]+P[1])*Pow(P[0],P[1]-2,P[1])%P[1])*P[0];//前2项通项
        ll xs=(x%p+(S[2][i]-x%P[2]+P[2])*Pow((ll)P[0]*P[1],P[2]-2,P[2])%P[2]*P[0]%p*P[1]%p)%p;
        Putint(xs),*Outp++=i==m?'\n':' ';
    }
    return fwrite(Out,1,Outp-Out,stdout),0;
}

Code length: 2.40 KB

Total time: 4.21 s

Memory: 12.87 MB

Slowest test case: 522 ms

This MTT on \(10^5\) elements takes about as long as my NTT on \(10^6\) elements... It is nearly as slow as an \(O(n\log^2 n)\) algorithm.


FFT with coefficient splitting

Let \(M\) be a constant, and split every coefficient of each polynomial into the form \(A \cdot M + B\) with \(0 \le B < M\). This gives \(A_1, B_1\) for the first polynomial and \(A_2, B_2\) for the second. Then
\[(A_1M + B_1)(A_2M + B_2) = A_1A_2M^2 + (A_1B_2 + A_2B_1)M + B_1B_2\]
so we only need the convolutions \(A_1A_2, A_1B_2, A_2B_1, B_1B_2\); adding them up gives the answer.

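When \(M = \sqrt{P}\) (with \(P\) the bound on the coefficients), each product of split coefficients is \(O(P)\). The values the FFT has to represent are therefore around \(10^{14}\) (\(10^5\) terms of size up to about \(10^9\)), and precision does not blow up.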

(What? Isn't \(A_1A_2M^2\) of order \(O(P^2)\)?)

In fact, we compute the convolution \(A_1A_2\) first, reduce it modulo \(p\), and only then multiply by \(M^2\), again modulo \(p\).

Computing the \(4\) convolutions independently would take \(12\) DFTs. (Wouldn't that be even slower than NTT?)

If we transform each of \(A_1, A_2, B_1, B_2\) once and reuse the results, only \(7\) DFTs are needed:

DFT(\(A_1\)),DFT(\(A_2\)),DFT(\(B_1\)),DFT(\(B_2\)),IDFT(\(A_1A_2\)),IDFT(\(A_1B_2+A_2B_1\)),IDFT(\(B_1B_2\))

Q: Isn't that still quite slow?

We can optimize further. Merging DFTs (described at the end of the FFT study notes) brings the count down to \(4\).

Specifically, the \(4\) forward DFTs become \(2\), and the \(3\) inverse DFTs become \(2\).

That makes it run fast. It can actually be pushed down to "\(3.5\)" DFTs, but the gain is small and the method is complicated; see myy's 2016 Chinese National Team paper "Fast Fourier Transform Revisited".


Code (unoptimized version):

This is a direct implementation of the ideas above.

Example: Luogu P4245 [Template] Arbitrary-Modulus NTT

(\(7\) DFTs; personally I find this version easier to write.)

// luogu-judger-enable-o2
#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#define rint register int
typedef long long ll;
typedef long double ld;

//Having A Daydream...

char In[1<<20],*p1=In,*p2=In,Ch;
// Getchar: next byte from a 1 MiB stdin buffer that is refilled with fread when empty;
// yields EOF once fread returns nothing.
#define Getchar (p1==p2&&(p2=(p1=In)+fread(In,1,1<<20,stdin),p1==p2)?EOF:*p1++)
// Reads the next non-negative decimal integer, skipping any non-digit bytes before it.
// Returns 0 if input ends before a digit is seen. The previous version looped forever
// on EOF, because isdigit(EOF) is false and the skip loop never tested for it.
// Bytes are passed to isdigit as unsigned char: a negative char value is UB there.
inline int Getint()
{
    int x=0,c;
    do c=Getchar;
    while(c!=EOF&&!isdigit((unsigned char)c));
    for(;c!=EOF&&isdigit((unsigned char)c);c=Getchar)x=x*10+(c^48);
    Ch=(char)c; // keep the global look-ahead character updated as before
    return x;
}

char Out[22222222],*Outp=Out,St[22],*Tp=St;
// Appends the decimal digits of x (expected to be non-negative) to the Out buffer.
// No separator is written; the caller appends one itself.
inline void Putint(int x)
{
    // Stack the digits on St, least significant first.
    for(;;)
    {
        *Tp++=x%10^48;
        x/=10;
        if(!x)break;
    }
    // Unstack them so the most significant digit is emitted first.
    while(Tp!=St)*Outp++=*--Tp;
}

// Eps and e are not used by this program; kept so existing references keep compiling.
const double Eps=1e-8,e=std::exp(1);
// Pi is long double. With a double Pi, 2*Pi*i/n is evaluated in double and std::cos picks
// the double overload, so the twiddle factors would lose the precision long double provides.
const long double Pi=std::acos(-1.0L);
// Minimal complex number over long double, used by the FFT below.
struct Complex
{
    long double x,y; // real and imaginary parts
    inline Complex operator+(const Complex &o)const{return (Complex){x+o.x,y+o.y};}
    inline Complex operator-(const Complex &o)const{return (Complex){x-o.x,y-o.y};}
    inline Complex operator*(const Complex &o)const{return (Complex){x*o.x-y*o.y,x*o.y+y*o.x};}
    inline Complex operator/(const long double k)const{return (Complex){x/k,y/k};}
    inline Complex Conj()const{return (Complex){x,-y};}
}Ome[1<<18],Inv[1<<18]; // forward / inverse twiddle tables, filled by Poly::Pre

int r[1<<18];
namespace Poly
{
    void Pre(int n)
    {
        for(rint i=0,l=(int)log2(n);i<n;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
        for(rint i=0;i<n;++i)
        {
            ld x=std::cos(2*Pi*i/n),y=std::sin(2*Pi*i/n);
            Ome[i]=(Complex){x,y},Inv[i]=(Complex){x,-y};
        }
    }

    void FFT(int n,Complex *A,Complex *T)
    {
        for(rint i=0;i<n;++i)if(i<r[i])std::swap(A[i],A[r[i]]);
        for(rint i=2;i<=n;i<<=1)
            for(rint j=0,h=i>>1;j<n;j+=i)
                for(rint k=0;k<h;++k)
                {
                    Complex Tmp=A[j+h+k]*T[n/i*k];
                    A[j+h+k]=A[j+k]-Tmp,A[j+k]=A[j+k]+Tmp;
                }
    }

    Complex A1[1<<18],B1[1<<18],A2[1<<18],B2[1<<18];
    Complex A[1<<18],B[1<<18],C[1<<18];
    void MTT(int n,int p,int *F,int *G,int *S)
    {
        //这里为了方便直接设M=2^15=32768
        for(rint i=0;i<n;++i)
        {
            A1[i].x=F[i]>>15,B1[i].x=F[i]&0x7FFF;
            A2[i].x=G[i]>>15,B2[i].x=G[i]&0x7FFF;
        }
        FFT(n,A1,Ome),FFT(n,B1,Ome),FFT(n,A2,Ome),FFT(n,B2,Ome);
        for(rint i=0;i<n;++i)
        {
            A[i]=A1[i]*A2[i];
            B[i]=A1[i]*B2[i]+A2[i]*B1[i];
            C[i]=B1[i]*B2[i];
        }
        FFT(n,A,Inv),FFT(n,B,Inv),FFT(n,C,Inv);
        for(rint i=0;i<n;++i)
        {
            ll Av=(ll)round(A[i].x/n),Bv=(ll)round(B[i].x/n),Cv=(ll)round(C[i].x/n);
            S[i]=((Av%p<<30)+(Bv%p<<15)+Cv)%p;
        }
    }
}

int n,m,p,F[1<<18],G[1<<18],S[1<<18];

int main()
{
    // Header: the two degrees and the modulus; then n+1 and m+1 coefficients.
    n=Getint();
    m=Getint();
    p=Getint();
    for(int i=0;i<=n;++i)F[i]=Getint();
    for(int i=0;i<=m;++i)G[i]=Getint();
    // m becomes the product degree; n the smallest power of two strictly above it.
    m+=n;
    n=1;
    while(n<=m)n<<=1;
    Poly::Pre(n);
    Poly::MTT(n,p,F,G,S);
    for(int i=0;i<=m;++i)
    {
        Putint(S[i]);
        *Outp++=(i==m)?'\n':' ';
    }
    fwrite(Out,1,Outp-Out,stdout);
    return 0;
}

Code length: 2.96 KB

Total time: 2.59 s

Memory: 80.93 MB

Slowest test case: 344 ms

Hmm, this is much faster than the NTT version above. Maybe my NTT is just badly written? (Update: after looking through other people's code... what on earth did I write?)

This version is recommended in an exam because it is easy to understand. Its drawbacks are high memory use and limited precision, which is why `long double` is required.

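Tip: `std::cos` is more precise than plain `cos`, because it has a `long double` overload while `::cos` may only take `double`. The same applies to the other math functions — but only when the argument itself is a `long double`.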


Code (DFT optimized version):

(\(5\) DFTs)

Actually, I don't understand how to merge the IDFTs...

Why does all the code online look different from mine?

Example: Luogu P4245 [Template] Arbitrary-Modulus NTT

//Luogu O2
#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#define rint register int
typedef long long ll;
typedef long double ld;

//Having A Daydream...

char In[1<<20],*p1=In,*p2=In,Ch;
// Getchar: next byte from a 1 MiB stdin buffer that is refilled with fread when empty;
// yields EOF once fread returns nothing.
#define Getchar (p1==p2&&(p2=(p1=In)+fread(In,1,1<<20,stdin),p1==p2)?EOF:*p1++)
// Reads the next non-negative decimal integer, skipping any non-digit bytes before it.
// Returns 0 if input ends before a digit is seen. The previous version looped forever
// on EOF, because isdigit(EOF) is false and the skip loop never tested for it.
// Bytes are passed to isdigit as unsigned char: a negative char value is UB there.
inline int Getint()
{
    int x=0,c;
    do c=Getchar;
    while(c!=EOF&&!isdigit((unsigned char)c));
    for(;c!=EOF&&isdigit((unsigned char)c);c=Getchar)x=x*10+(c^48);
    Ch=(char)c; // keep the global look-ahead character updated as before
    return x;
}

char Out[22222222],*Outp=Out,St[22],*Tp=St;
// Appends the decimal digits of x (expected to be non-negative) to the Out buffer.
// No separator is written; the caller appends one itself.
inline void Putint(int x)
{
    // Stack the digits on St, least significant first.
    for(;;)
    {
        *Tp++=x%10^48;
        x/=10;
        if(!x)break;
    }
    // Unstack them so the most significant digit is emitted first.
    while(Tp!=St)*Outp++=*--Tp;
}

// Eps and e are not used by this program; kept so existing references keep compiling.
const double Eps=1e-8,e=std::exp(1);
// Pi is long double. With a double Pi, 2*Pi*i/n is evaluated in double and std::cos picks
// the double overload, so the twiddle factors would lose the precision long double provides.
const long double Pi=std::acos(-1.0L);
// Minimal complex number over long double, used by the FFT below.
struct Complex
{
    long double x,y; // real and imaginary parts
    inline Complex operator+(const Complex &o)const{return (Complex){x+o.x,y+o.y};}
    inline Complex operator-(const Complex &o)const{return (Complex){x-o.x,y-o.y};}
    inline Complex operator*(const Complex &o)const{return (Complex){x*o.x-y*o.y,x*o.y+y*o.x};}
    inline Complex operator/(const long double k)const{return (Complex){x/k,y/k};}
    inline Complex Conj()const{return (Complex){x,-y};}
}Ome[1<<18],Inv[1<<18],I=(Complex){0,1}; // twiddle tables (filled by Poly::Pre) and the imaginary unit

int r[1<<18];
namespace Poly
{
    void Pre(int n)
    {
        for(rint i=0,l=(int)log2(n);i<n;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
        for(rint i=0;i<n;++i)
        {
            ld x=std::cos(2*Pi*i/n),y=std::sin(2*Pi*i/n);
            Ome[i]=(Complex){x,y},Inv[i]=(Complex){x,-y};
        }
    }

    void FFT(int n,Complex *A,Complex *T)
    {
        for(rint i=0;i<n;++i)if(i<r[i])std::swap(A[i],A[r[i]]);
        for(rint i=2;i<=n;i<<=1)
            for(rint j=0,h=i>>1;j<n;j+=i)
                for(rint k=0;k<h;++k)
                {
                    Complex Tmp=A[j+h+k]*T[n/i*k];
                    A[j+h+k]=A[j+k]-Tmp,A[j+k]=A[j+k]+Tmp;
                }
    }

    Complex P[1<<18],Q[1<<18];
    void Double_DFT(int n,Complex *A,Complex *B,Complex *T)
    {
        for(rint i=0;i<n;++i)P[i]=A[i]+B[i]*I,Q[i]=A[i]-B[i]*I;
        FFT(n,P,T);
        for(rint i=0;i<n;++i)Q[i]=(i?P[n-i]:P[0]).Conj();
        for(rint i=0;i<n;++i)A[i]=(P[i]+Q[i])/2,B[i]=(P[i]-Q[i])*I/-2;
    }

    Complex A1[1<<18],B1[1<<18],A2[1<<18],B2[1<<18];
    Complex A[1<<18],B[1<<18],C[1<<18];
    void MTT(int n,int p,int *F,int *G,int *S)
    {
        //这里为了方便直接设M=2^15=32768
        for(rint i=0;i<n;++i)
        {
            A1[i].x=F[i]>>15,B1[i].x=F[i]&0x7FFF;
            A2[i].x=G[i]>>15,B2[i].x=G[i]&0x7FFF;
        }
        //FFT(n,A1,Ome),FFT(n,B1,Ome),FFT(n,A2,Ome),FFT(n,B2,Ome);
        Double_DFT(n,A1,B1,Ome),Double_DFT(n,A2,B2,Ome);
        for(rint i=0;i<n;++i)
        {
            A[i]=A1[i]*A2[i];
            B[i]=A1[i]*B2[i]+A2[i]*B1[i];
            C[i]=B1[i]*B2[i];
        }
        FFT(n,A,Inv),FFT(n,B,Inv),FFT(n,C,Inv);
        //Double_DFT(n,A,B,Inv),FFT(n,C,Inv);//IDFT怎么合并?
        for(rint i=0;i<n;++i)
        {
            ll Av=(ll)round(A[i].x/n),Bv=(ll)round(B[i].x/n),Cv=(ll)round(C[i].x/n);
            S[i]=((Av%p<<30)+(Bv%p<<15)+Cv)%p;
        }
    }
}

int n,m,p,F[1<<18],G[1<<18],S[1<<18];

int main()
{
    // Header: the two degrees and the modulus; then n+1 and m+1 coefficients.
    n=Getint();
    m=Getint();
    p=Getint();
    for(int i=0;i<=n;++i)F[i]=Getint();
    for(int i=0;i<=m;++i)G[i]=Getint();
    // m becomes the product degree; n the smallest power of two strictly above it.
    m+=n;
    n=1;
    while(n<=m)n<<=1;
    Poly::Pre(n);
    Poly::MTT(n,p,F,G,S);
    for(int i=0;i<=m;++i)
    {
        Putint(S[i]);
        *Outp++=(i==m)?'\n':' ';
    }
    fwrite(Out,1,Outp-Out,stdout);
    return 0;
}

Code length: 3.41 KB

Total time: 2.21 s

Memory: 94.60 MB

Slowest test case: 282 ms

The optimization doesn't gain much, does it...


Summary

In fact, MTT is not hard — it's just a trick, isn't it?

I'm too weak ("Tcl") to understand the IDFT merge, so I'll just stick with the \(7\)-DFT version...

References:

2016 Chinese National Team paper "Fast Fourier Transform Revisited" — Mao Xiao (myy, matthew99)

Guess you like

Source: www.cnblogs.com/LanrTabe/p/11314179.html