【洛谷P5158】 【模板】多项式快速插值

卡常严重,可有采用如下优化方案: 

1.预处理单位根 

2.少取几次模 

3.复制数组时用 memcpy     

4.进行多项式乘法项数少的时候直接暴力乘 

5.进行多项式多点求值时如果项数小于500的话直接秦九昭展开

code: 

#include <bits/stdc++.h>     
#define ll long long 
#define ull unsigned long long 
#define setIO(s) freopen(s".in","r",stdin) // , freopen(s".out","w",stdout)        
using namespace std;  
char buf[100000],*p1,*p2;
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int rd() 
{
    int x=0; char s=nc();
    while(s<'0') s=nc();
    while(s>='0') x=(((x<<2)+x)<<1)+s-'0',s=nc();
    return x;
}         
void print(int x) {if(x>=10) print(x/10);putchar(x%10+'0');}
const int G=3;  
const int N=2000005;   
const int mod=998244353;                   
int A[N],B[N],w[2][N],mem[N*100],*ptr=mem;      
inline int qpow(int x,int y) 
{
    int tmp=1;     
    for(;y;y>>=1,x=(ll)x*x%mod)     if(y&1) tmp=(ll)tmp*x%mod;  
    return tmp;    
}      
inline int INV(int a) { return qpow(a,mod-2); }        
inline void ntt_init(int len) 
{
    int i,j,k,mid,x,y;      
    w[1][0]=w[0][0]=1,x=qpow(3,(mod-1)/len),y=qpow(x,mod-2);
    for (i=1;i<len;++i) w[0][i]=(ll)w[0][i-1]*x%mod,w[1][i]=(ll)w[1][i-1]*y%mod;         
}
void NTT(int *a,int len,int flag) 
{
    int i,j,k,mid,x,y;                
    for(i=k=0;i<len;++i) 
    {
        if(i>k)    swap(a[i],a[k]);  
        for(j=len>>1;(k^=j)<j;j>>=1);  
    }   
    for(mid=1;mid<len;mid<<=1)            
        for(i=0;i<len;i+=mid<<1) 
            for(j=0;j<mid;++j)          
            {
                x=a[i+j], y=(ll)w[flag==-1][len/(mid<<1)*j]*a[i+j+mid]%mod;  
                a[i+j]=(x+y)%mod;  
                a[i+j+mid]=(x-y+mod)%mod;   
            }   
    if(flag==-1)  
    {
        int rev=INV(len);   
        for(i=0;i<len;++i)    a[i]=(ll)a[i]*rev%mod;   
    }
}              
inline void getinv(int *a,int *b,int len,int la) 
{
    if(len==1) { b[0]=INV(a[0]);   return; }
    getinv(a,b,len>>1,la);    
    int l=len<<1,i;   
    memset(A,0,l*sizeof(A[0]));              
    memset(B,0,l*sizeof(A[0]));    
    memcpy(A,a,min(la,len)*sizeof(a[0]));                                                      
    memcpy(B,b,len*sizeof(b[0]));             
    ntt_init(l);   
    NTT(A,l,1),NTT(B,l,1);      
    for(i=0;i<l;++i)  A[i]=((ll)2-(ll)A[i]*B[i]%mod+mod)*B[i]%mod;
    NTT(A,l,-1);                                 
    memcpy(b,A,len<<2);          
}  
struct poly 
{
    int len,*a;    
    poly(){}       
    poly(int l) {len=l,a=ptr,ptr+=l; }            
    inline void rev() { reverse(a,a+len); }       
    inline void fix(int l) {len=l,a=ptr,ptr+=l;}   
    inline void get_mod(int l) { for(int i=l;i<len;++i) a[i]=0;  len=l;  }
    inline poly dao() 
    {        
        poly re(len-1);   
        for(int i=1;i<len;++i)  re.a[i-1]=(ll)i*a[i]%mod;         
        return re;    
    }    
    inline poly Inv(int l) 
    {  
        poly b(l);              
        getinv(a,b.a,l,len);                                  
        return b;                        
    }                                                                    
    inline poly operator * (const poly &b) const 
    {
        poly c(len+b.len-1);   
        if(c.len<=500) 
        {         
            for(int i=0;i<len;++i)   
                if(a[i])   for(int j=0;j<b.len;++j)  c.a[i+j]=(c.a[i+j]+(ll)(a[i])*b.a[j])%mod;      
            return c; 
        }
        int n=1;    
        while(n<(len+b.len)) n<<=1; 
        memset(A,0,n<<2);  
        memset(B,0,n<<2);   
        memcpy(A,a,len<<2);                             
        memcpy(B,b.a,b.len<<2);            
        ntt_init(n);        
        NTT(A,n,1), NTT(B,n,1);     
        for(int i=0;i<n;++i) A[i]=(ll)A[i]*B[i]%mod;   
        NTT(A,n,-1);   
        memcpy(c.a,A,c.len<<2);  
        return c;       
    }    
    poly operator + (const poly &b) const 
    {
        poly c(max(len,b.len));    
        for(int i=0;i<c.len;++i)  c.a[i]=((i<len?a[i]:0)+(i<b.len?b.a[i]:0))%mod;   
        return c;    
    }
    poly operator - (const poly &b) const 
    {    
        poly c(len);       
        for(int i=0;i<len;++i)   
        {
            if(i>=b.len)   c.a[i]=a[i];  
            else c.a[i]=(a[i]-b.a[i]+mod)%mod;    
        } 
        return c;  
    }
    poly operator / (poly u) 
    {  
        int n=len,m=u.len,l=1;  
        while(l<(n-m+1)) l<<=1;                           
        rev(),u.rev();            
        poly v=u.Inv(l);    
        v.get_mod(n-m+1);        
        poly re=(*this)*v;   
        rev(),u.rev();    
        re.get_mod(n-m+1);         
        re.rev();  
        return re;   
    }      
    poly operator % (poly u) 
    {      
        poly re=(*this)-u*(*this/u);        
        re.get_mod(u.len-1);       
        return re;    
    }                     
}p[N<<2],pr;    
int xx[N],yy[N];              
#define lson now<<1  
#define rson now<<1|1           
inline void pushup(int l,int r,int now)
{
    int mid=(l+r)>>1;      
    if(r>mid)   p[now]=p[lson]*p[rson]; 
    else p[now]=p[lson];   
}
void build(int l,int r,int now,int *pp) 
{
    if(l==r) 
    {     
        p[now].fix(2);  
        p[now].a[0]=mod-pp[l];  
        p[now].a[1]=1;   
        return; 
    }  
    int mid=(l+r)>>1;   
    if(l<=mid)  build(l,mid,lson,pp);     
    if(r>mid)   build(mid+1,r,rson,pp);          
    p[now]=p[lson]*p[rson];   
}    
void get_val(int l,int r,int now,poly b,int *pp,int *t) 
{
    if(b.len<=500)     
    {   
        for(int i=l;i<=r;++i) 
        {
            ull s=0;             
            for(int j=b.len-1;j>=0;--j)     
            {
                s=((ull)s*pp[i]+b.a[j])%mod;  
                if(!(j&7))   s%=mod;       
            }
            t[i]=s%mod;   
        }
        return;  
    } 
    int mid=(l+r)>>1;     
    if(l<=mid)   get_val(l,mid,lson,b%p[lson],pp,t);  
    if(r>mid)    get_val(mid+1,r,rson,b%p[rson],pp,t);     
}   
poly solve_polate(int l,int r,int now,int *t) 
{
    if(l==r) 
    {
        poly re(1);   
        re.a[0]=t[l];   
        return re;   
    } 
    int mid=(l+r)>>1;    
    poly L,R;  
    L=solve_polate(l,mid,lson,t);   
    R=solve_polate(mid+1,r,rson,t);   
    return L*p[rson]+R*p[lson];           
}       
void check_Interpolate();  
poly Interpolate(int *a,int *b,int n);       
void check_Evaluation();  
void check_Inv();   
void check_mult();   
void check_divide();       
poly Interpolate(int *a,int *b,int n) 
{ 
    int i,j;   
    build(1,n,1,a);        
    static int t[N];  
    poly tmp=p[1].dao();          
    get_val(1,n,1,tmp,a,t);                            
    for(i=1;i<=n;++i)    t[i]=(ll)INV(t[i])*b[i]%mod;                       
    return solve_polate(1,n,1,t);    
}
void check_Interpolate() 
{
    // setIO("input");        
    int i,j,n; 
    n=rd(); 
    for(i=1;i<=n;++i)      xx[i]=rd(),yy[i]=rd(); 
    poly re=Interpolate(xx,yy,n);                 
    for(i=0;i<re.len;++i)       print(re.a[i]), printf(" "); 
    for(;i<n;++i)    print(re.a[i]), printf(" ");      
}
void check_Evaluation() 
{   
    int i,j,n,m,l; 
    n=rd(),m=rd();              
    pr.fix(n+1);         
    static int pp[N];    
    for(i=0;i<=n;++i)   pr.a[i]=rd();   
    for(i=1;i<=m;++i)   pp[i]=rd();   
    build(1,m,1,pp);   
    get_val(1,m,1,pr,pp,pp);                                         
    for(i=1;i<=m;++i)   printf("%d\n",pp[i]);                    
}
void check_Inv() 
{
    int i,j,n; 
    scanf("%d",&n);    
    pr.fix(n);   
    for(i=0;i<n;++i)   scanf("%d",&pr.a[i]);      
    int l=1; 
    while(l<n)  l<<=1;  
    pr=pr.Inv(l);   
    for(i=0;i<n;++i)   printf("%d ",pr.a[i]);      
}
void check_mult() 
{
    int i,j,n,m; 
    scanf("%d%d",&n,&m);  
    poly a(n+1),b(m+1);  
    for(i=0;i<=n;++i)   scanf("%d",&a.a[i]); 
    for(i=0;i<=m;++i)   scanf("%d",&b.a[i]); 
    a=a*b;  
    for(i=0;i<a.len;++i)   printf("%d ",a.a[i]); 
}
void check_divide() 
{
    int i,j,n,m;   
    scanf("%d%d",&n,&m);    
    poly F(n+1), G(m+1);   
    for(i=0;i<=n;++i)    scanf("%d",&F.a[i]);  
    for(i=0;i<=m;++i)    scanf("%d",&G.a[i]);  
    poly Q=F/G;  
    poly R=F%G;  
    for(i=0;i<Q.len;++i)    printf("%d ",Q.a[i]);    
    printf("\n");   
    for(i=0;i<R.len;++i)    printf("%d ",R.a[i]);   
}    

  

猜你喜欢

转载自www.cnblogs.com/guangheli/p/11928600.html
今日推荐