Luogu P4512 【模板】多项式除法 (Polynomial Division template): NTT + polynomial inverse + polynomial division

Code:

#include<bits/stdc++.h>
#define maxn 300000 
#define ll long long 
#define MOD 998244353 
#define setIO(s) freopen(s".in","r",stdin) ,freopen(s".out","w",stdout) 
using namespace std;        
namespace poly{
    // NTT parameters for MOD = 998244353 = 119 * 2^23 + 1.
    // These are namespace-scope constants rather than #defines: a macro written
    // inside a namespace still leaks into the rest of the translation unit.
    const int GROUP_ORDER = MOD - 1;  // order of the multiplicative group modulo MOD
    const int PRIMITIVE_ROOT = 3;     // generator used to build the roots of unity

    int rev[maxn];        // bit-reversal permutation for the current transform size
    ll X[maxn],Y[maxn];   // scratch buffers used by mul()

    // Fill rev[1..lim) with the bit-reversal of each index over l bits (callers pass lim == 1 << l).
    // rev[0] is never written; it relies on the zero-initialization of the global array.
    void calrev(int lim,int l){ for(int i=1;i<lim;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1)); }

    // (a + b) mod MOD; both arguments are expected to be non-negative.
    ll add(ll a,ll b){ return (a+b)%MOD; }

    // a^k mod MOD by binary exponentiation (k >= 0).
    ll qpow(ll a,ll k){
        ll base=1;
        for(;k;a=a*a%MOD,k>>=1) if(k&1) base=base*a%MOD;
        return base;
    }

    // In-place iterative NTT of a[0..len).
    // len must be a power of two (at most 2^23 so it divides GROUP_ORDER), and rev[]
    // must have been prepared by calrev for this len.
    // opt == 1: forward transform; opt == -1: inverse transform, including the division by len.
    void NTT(ll *a,int len,int opt){
        for(int i=0;i<len;++i) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
        for(int i=1;i<len;i<<=1){
            int step=i<<1;
            // PRIMITIVE_ROOT^(+-GROUP_ORDER/step); adding GROUP_ORDER keeps the exponent non-negative.
            ll wn=qpow(PRIMITIVE_ROOT,opt*GROUP_ORDER/step+GROUP_ORDER);
            for(int j=0;j<len;j+=step){
                ll w=1;
                for(int k=0;k<i;++k,w=w*wn%MOD){
                    ll x=a[j+k];
                    ll y=w*a[j+k+i]%MOD;
                    a[j+k]=(x+y)%MOD;
                    a[j+k+i]=(x-y+MOD)%MOD;
                }
            }
        }
        if(opt==-1){
            ll r=qpow(len,MOD-2);
            for(int i=0;i<len;++i) a[i]=a[i]*r%MOD;
        }
    }

    // x <- x * y, using only the first lim/2 coefficients of each operand; writes x[0..lim).
    // Because each operand has at most lim/2 coefficients, the product fits in lim slots
    // without cyclic wrap-around. x and y may be the same array (both are copied first).
    // lim must be a power of two with rev[] prepared by calrev(lim, log2(lim)).
    void mul(ll *x,ll *y,int lim){
        // NTT only touches [0, lim), so clearing more of the scratch buffers is wasted work.
        std::fill(X,X+lim,0LL);
        std::fill(Y,Y+lim,0LL);
        for(int i=0;i<(lim>>1);++i) X[i]=x[i],Y[i]=y[i];
        NTT(X,lim,1),NTT(Y,lim,1);
        for(int i=0;i<lim;++i) X[i]=X[i]*Y[i]%MOD;
        NTT(X,lim,-1);
        for(int i=0;i<lim;++i) x[i]=X[i];
    }

    // B[0] and B[1] are get_inv's ping-pong buffers; B[2], C and D are not referenced
    // by the code in this file.
    ll B[3][maxn],C[maxn],D[maxn];

    // Replace A[0..n] with the first n+1 coefficients of the power-series inverse 1/A,
    // using the Newton iteration B <- 2B - B^2 * A, which doubles the precision each round.
    // Requires A[0] != 0 (mod MOD). Transform sizes reach up to about 4n, which must not
    // exceed maxn.
    void get_inv(int n,ll *A){
        int cur=0,bas=1,lim=2,len=1;
        B[cur][0]=qpow(A[0],MOD-2);  // 1/A mod x^1
        calrev(lim,len);
        // Round with bas == b: B[cur^1] holds 1/A mod x^(b/2) (mod x^1 in the first round);
        // produce 1/A mod x^b in B[cur]. lim == 2b, so every product fits without wrap-around.
        while(bas<=(n<<1)){
            cur^=1;
            // The next round reads this buffer's first lim entries, so only those need zeroing.
            std::fill(B[cur],B[cur]+lim,0LL);
            for(int i=0;i<bas;++i) B[cur][i]=B[cur^1][i]*2%MOD;           // 2B
            mul(B[cur^1],B[cur^1],lim),mul(B[cur^1],A,lim);                // B^2 * A (old buffer reused as scratch)
            for(int i=0;i<bas;++i) B[cur][i]=add(B[cur][i],MOD-B[cur^1][i]); // 2B - B^2 * A
            bas<<=1,lim<<=1,++len;
            if(bas<=(n<<1)) calrev(lim,len);
        }
        // The last round ran with bas > n, so B[cur][0..n] is exact.
        for(int i=0;i<=n;++i) A[i]=B[cur][i];
    }
}  // namespace poly
// A, B: input polynomials F (degree n) and G (degree m), low degree first.
// Ar, Br: coefficient-reversed copies of F and G; Dr: receives the quotient Q.
ll A[maxn],B[maxn],n,m,lim,len;
ll Ar[maxn],Br[maxn],Dr[maxn];

// Reads F and G, then prints Q = F div G on the first line and R = F mod G on the
// second, low degree first.
// Method: with reversed polynomials F_R and G_R, Q_R = F_R * G_R^{-1} mod x^(n-m+1),
// and then R = F - G*Q (only degrees below m survive). Assumes m <= n.
int main(){
    //setIO("input");
    // n, m and the coefficient arrays are long long, so every conversion must be %lld
    // (reading them with %d is undefined behavior). Stop on malformed input.
    if(scanf("%lld%lld",&n,&m)!=2) return 1;
    for(int i=0;i<=n;++i){
        if(scanf("%lld",&A[i])!=1) return 1;
        Ar[i]=A[i];
    }
    for(int i=0;i<=m;++i){
        if(scanf("%lld",&B[i])!=1) return 1;
        Br[i]=B[i];
    }
    std::reverse(Ar,Ar+n+1),std::reverse(Br,Br+m+1);
    // Truncate both reversed polynomials to their first n-m+2 coefficients;
    // Q_R mod x^(n-m+1) depends only on the first n-m+1 of them.
    for(int i=n-m+2;i<=std::max(n,m);++i) Br[i]=Ar[i]=0;
    poly::get_inv(n-m+1,Br);       // Br <- inverse of Br (first n-m+2 coefficients)
    // Need lim/2 >= n-m+2 so mul() sees every nonzero coefficient of Ar and Br.
    lim=1,len=0;
    while(lim<=n-m+1+n-m+1) lim<<=1,++len;
    poly::calrev(lim,len), poly::mul(Ar,Br,lim);
    // Ar[0..n-m] is Q_R: print Q by reading it backwards, and store Q in Dr.
    for(int i=n-m;i>=0;--i) printf("%lld ",Ar[i]),Dr[n-m-i]=Ar[i];

    // Need lim/2 >= n+1 so mul() sees all of G (m+1 coeffs) and Q (n-m+1 coeffs).
    lim=1,len=0;
    while(lim<=n*2) lim<<=1,++len;
    poly::calrev(lim,len),poly::mul(B,Dr,lim);   // B <- G * Q
    printf("\n");
    // R = F - G*Q, degrees 0..m-1.
    for(int i=0;i<=m-1;++i) {
        ll h=(A[i]-B[i]+MOD)%MOD;
        printf("%lld ",h);
    }
    return 0;
}

  

猜你喜欢

转载自 (reposted from) www.cnblogs.com/guangheli/p/10739668.html