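// [Template] Polynomial operations: multiplication / inverse / divide-and-conquer FFT /
// derivative & integral / ln / exp / power.
// (Original title: [模板] 多项式: 乘法/求逆/分治fft/微积分/ln/exp/幂)
//
// Polynomials — code follows.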

// Capacity (in ints) of every polynomial buffer in this file.
const int nsz=400000+50;
// NTT-friendly prime 998244353 = 119*2^23+1.
const ll nmod=998244353;
// Primitive root of nmod used to generate NTT roots of unity.
const ll g=3;
// Modular inverse of g (3*332748118 = 998244354 = 1 mod nmod), used for the inverse NTT.
const ll ginv=332748118ll;


//basic math

// Binary exponentiation: returns a^b mod nmod.
// Intended for b >= 0 and a in [0, nmod) (assumption — callers in this file pass such values).
ll qp(ll a,ll b){
    ll result=1;
    while(b!=0){
        if(b&1)result=result*a%nmod;
        a=a*a%nmod;
        b>>=1;
    }
    return result;
}
// Modular inverse by Fermat's little theorem (nmod is prime): n^(nmod-2) mod nmod.
// Yields 0 when n is a multiple of nmod (no inverse exists).
ll inv(ll n){
    const ll fermatExponent=nmod-2;
    return qp(n,fermatExponent);
}


//polynomial operations
//ntt version

namespace npoly{
    //the l means length of array, which means the polynomial has degree l-1
    ////^ for simplifying the doubling process in dft

    typedef int tpoly[nsz];

    tpoly a,b,ans;
    void cl(int *a,int l,int r){memset(a+l,0,(r-l+1)<<2);}
    void cp(int *a,int l,int r,int *b){memcpy(b,a+l,(r-l+1)<<2);}
    
    int len,rev[nsz];
    void fftinit(int l0){
        int l=0;
        while((1<<l)<l0)++l;
        len=(1<<l);
        rep(i,0,len-1){
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
        }
    }

    void dft(int *a,int len,int fl){
        rep(i,0,len-1)if(i<rev[i])swap(a[i],a[rev[i]]);
        for(int i=1;i<len;i<<=1){
            ll wn=qp((fl==1?g:ginv),(nmod-1)/(i<<1));
            for(int j=0,p=(i<<1);j<len;j+=p){
                ll w=1;
                for(int k=0;k<i;++k,w=w*wn%nmod){
                    ll x=a[j+k],y=a[j+k+i]*w%nmod;
                    a[j+k]=(x+y)%nmod,a[j+k+i]=(x-y+nmod)%nmod;
                }
            }
        }
        if(fl==-1){
            ll v=inv(len);
            rep(i,0,len-1)a[i]=a[i]*v%nmod;
        }
    }
    
    void mul(int *a,int l1,int *b,int l2,int *c){//c=a*b
        static int c1[nsz],c2[nsz];
        fftinit(l1+l2-1);
        cp(a,0,l1-1,c1);
        cl(c1,l1,len-1);
        cp(b,0,l2-1,c2),cl(c2,l2,len-1);

        dft(c1,len,1),dft(c2,len,1);
        rep(i,0,len-1)c1[i]=(ll)c1[i]*c2[i]%nmod;
        dft(c1,len,-1);

        cp(c1,0,l1+l2-2,c);
    }


    void _inverse(int *a,int l0,int *b){
        static int c1[nsz],c2[nsz];
        if(l0==1){b[0]=inv(a[0]);return;}
        _inverse(a,l0>>1,b);
        fftinit(l0<<1);//需要两倍长度dft保证乘法正确
        cp(a,0,l0-1,c1),cl(c1,l0,len-1);
        cp(b,0,(l0>>1)-1,c2),cl(c2,(l0>>1),len-1);

        dft(c1,len,1),dft(c2,len,1);
        rep(i,0,len-1)c2[i]=(ll)c2[i]*(2-((ll)c1[i]*c2[i])%nmod+nmod)%nmod;
        dft(c2,len,-1);

        cp(c2,0,l0-1,b);
    }
    bool inverse(int *a,int l0,int *b){//1 succeed;  0 fail
        if(a[0]==0)return 0;
        static int c1[nsz];
        int l1=1;
        while(l1<l0)l1<<=1;
        cp(a,0,l0-1,c1),cl(c1,l0,l1-1);
        _inverse(c1,l1,b);
        cl(b,l0,l1-1);
        return 1;
    }
    void dncfft(int *a,int l0,int *b){
        static int c1[nsz];
        rep(i,0,l0-1)c1[i]=nmod-a[i];//a[i]<nmod
        c1[0]=(c1[0]+1)%nmod;
        inverse(c1,l0,b);
    }

    void derivative(int *a,int l0,int *b){//b could = a
        rep(i,1,l0-1)b[i-1]=(ll)a[i]*i%nmod;
        b[l0-1]=0;
    }
    void integrate(int *a,int l0,int *b){
        repdo(i,l0-2,0)b[i+1]=(ll)inv(i+1)*a[i]%nmod;
        b[0]=0;
    }
    bool ln(int *a,int l0,int *b){//1 succeed;  0 fail
        if(a[0]==0)return 0;
        static int c1[nsz],c2[nsz];
        derivative(a,l0,c1);
        inverse(a,l0,c2);
        mul(c1,l0,c2,l0,b);
        integrate(b,l0,b);
        return 1;
    }
    void _exp(int *a,int l0,int *b){
        static int c1[nsz];
        if(l0==1){b[0]=1;return;}
        _exp(a,l0>>1,b);
        ln(b,l0,c1);
        rep(i,0,l0)c1[i]=(a[i]-c1[i]+nmod)%nmod;
        c1[0]=(c1[0]+1)%nmod;
        mul(c1,l0,b,l0,b);
    }
    bool exp(int *a,int l0,int *b){//1 succeed;  0 fail
        if(a[0])return 0;
        static int c1[nsz];
        int l1=1;
        while(l1<l0)l1<<=1;
        rep(i,0,l1)c1[i]=(i<l0?a[i]:0);
        _exp(c1,l1,b);//a[l0..(l1-1)] should be 0
        return 1;
    }
    void pow(int *a,int l0,int k,int *b){//suppose a[0]!=0;
        static int c1[nsz];
        ln(a,l0,c1);
        rep(i,0,l0-1)c1[i]=(ll)c1[i]*k%nmod;
        exp(c1,l0,b);
    }

    void sq(int *a,int l0,int *b){
        pow(a,l0,inv(2),b);
    }

//tests
    void testfft(){
        int n,m;
        cin>>n>>m,++n,++m;
        rep(i,0,n-1)cin>>a[i];
        rep(i,0,m-1)cin>>b[i];
        mul(a,n,b,m,ans);
        rep(i,0,n+m-2)cout<<ans[i]<<' ';
        cout<<'\n';
    }

    void testinv(){
        int n;
        cin>>n;
        rep(i,0,n-1)cin>>a[i];
        inverse(a,n,ans);
        rep(i,0,n-1)cout<<ans[i]<<' ';
        cout<<'\n';
    }
    void testdnc(){
        int n;
        cin>>n;
        rep(i,1,n-1)cin>>a[i];
        dncfft(a,n,ans);
        rep(i,0,n-1)cout<<ans[i]<<' ';
        cout<<'\n';
    }
    void testln(){
        int n;
        cin>>n;
        rep(i,0,n-1)cin>>a[i];
        ln(a,n,ans);
        rep(i,0,n-1)cout<<ans[i]<<' ';
        cout<<'\n';
    }
    void testexp(){
        int n;
        cin>>n;
        rep(i,0,n-1)cin>>a[i];
        exp(a,n,ans);
        rep(i,0,n-1)cout<<ans[i]<<' ';
        cout<<'\n';
    }   
    void testsq(){
        int n;
        cin>>n;
        rep(i,0,n-1)cin>>a[i];
        sq(a,n,ans);
        rep(i,0,n-1)cout<<ans[i]<<' ';
        cout<<'\n';
    }
}

// Global n (not referenced within this section).
int n;
int main(){
    // Decouple iostreams from C stdio and untie cin from cout for faster I/O.
    ios::sync_with_stdio(false);
    cin.tie(0);
    // Enable exactly one driver; each reads its input from stdin and prints the result.
//  npoly::testfft();
//  npoly::testinv();
//  npoly::testdnc();
//  npoly::testln();
//  npoly::testexp();
    npoly::testsq();
    return 0;
}

// Reposted from www.cnblogs.com/ubospica/p/10475184.html