多项式模板QAQ

版权声明:本文为一名蒟蒻的原创文章,大神转载的话顺便说个出处呗。 https://blog.csdn.net/cgh_Andy/article/details/79462760

失踪人口回归？
这些天抄了不少多项式题目的代码，留几个模板吧。
（也不知道算不算好模板。）
反正主要是写给自己看的。

每份代码最长也不超过 80 行。

UOJ #34. 多项式乘法

#include <bits/stdc++.h>
#define me(a,x) memset(a,x,sizeof a)
using namespace std;
const int N=3e5+2,inf=1e9+7;
const double pi=acos(-1);
char O[1<<14],*S=O,*T=O;
// gc: next byte of stdin (0..255) read through the 16 KiB buffer O, whose
// unread part is [S, T); evaluates to -1 once stdin is exhausted. The byte is
// widened through unsigned char so that 0xFF can never be confused with -1.
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:(unsigned char)*S++)
// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped (a '-' among them makes the result negative). Returns 0 when
// the input ends before any digit instead of spinning on end-of-input.
inline int read(){
    int x=0,f=1,ch=gc;   // int, so the -1 end-of-input marker is representable
    while(ch!=-1 && (ch<'0' || ch>'9')){if(ch=='-')f=-1; ch=gc;}
    while(ch>='0' && ch<='9'){x=x*10+(ch-'0'); ch=gc;}
    return x*f;
}
// Minimal complex number for the FFT: x is the real part, y the imaginary part.
struct P{
    double x,y;
    P():x(0),y(0){}
    P(double re,double im):x(re),y(im){}
}a[N],b[N],c[N];  // a, b: the two input polynomials; c: their pointwise product
P operator+(P u,P v){P r(u.x+v.x,u.y+v.y); return r;}
P operator-(P u,P v){P r(u.x-v.x,u.y-v.y); return r;}
// (u.x + i u.y)(v.x + i v.y) = (u.x v.x - u.y v.y) + i (u.y v.x + u.x v.y)
P operator*(P u,P v){
    const double re=u.x*v.x-u.y*v.y;
    const double im=u.y*v.x+u.x*v.y;
    return P(re,im);
}
// id: bit-reversal permutation for the transform length n = 2^ln (set in main);
// an, bn: coefficient counts of the two inputs; cn is not used.
int id[N],an,bn,cn,n,ln;
// In-place iterative radix-2 FFT of s[0..n-1].
// si = 1: forward transform; si = -1: inverse transform, after which only the
// real parts are divided by n (main reads nothing but .x afterwards).
void fft(P *s,int si){
    for(int i=1;i<n;++i)
        if(i<id[i]) std::swap(s[i],s[id[i]]);
    for(int half=1;half<n;half<<=1){
        const double ang=pi/half;
        P step(cos(ang),si*sin(ang));   // primitive (2*half)-th root of unity
        for(int base=0;base<n;base+=half<<1){
            P *lo=s+base,*hi=lo+half;
            P w(1,0);
            for(int k=0;k<half;++k){
                P u=lo[k],v=hi[k]*w;
                lo[k]=u+v;
                hi[k]=u-v;
                w=w*step;
            }
        }
    }
    if(si<0)
        for(int i=0;i<n;++i) s[i].x/=n;
}
int main(){
    an=read()+1,bn=read()+1;
    for(int i=0;i<an;++i) a[i].x=read();
    for(int i=0;i<bn;++i) b[i].x=read();
    n=1,ln=0; while(n<an+bn) n<<=1,++ln;
    for(int i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<(ln-1));
    fft(a,1); fft(b,1);
    for(int i=0;i<n;++i) c[i]=a[i]*b[i];
    fft(c,-1);
    printf("%d",int(c[0].x+0.5));
    for(int i=1;i<an+bn-1;++i) printf(" %d",int(c[i].x+0.5));
    puts("");
    return 0;
}

BZOJ 3992：用原根取离散对数，把乘积转成指数相加（生成函数），再用 NTT 做多项式快速幂

#include<bits/stdc++.h>
using namespace std;
const int N=16385,Mod=1004535809;
char O[1<<14],*S=O,*T=O;
// gc: next byte of stdin (0..255) read through the 16 KiB buffer O, whose
// unread part is [S, T); evaluates to -1 once stdin is exhausted. The byte is
// widened through unsigned char so that 0xFF can never be confused with -1.
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:(unsigned char)*S++)
// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped (a '-' among them makes the result negative). Returns 0 when
// the input ends before any digit instead of spinning on end-of-input.
inline int read(){
    int x=0,f=1,ch=gc;   // int, so the -1 end-of-input marker is representable
    while(ch!=-1 && (ch<'0' || ch>'9')){if(ch=='-')f=-1; ch=gc;}
    while(ch>='0' && ch<='9'){x=x*10+(ch-'0'); ch=gc;}
    return x*f;
}
// x^k mod `mod` by binary exponentiation; returns 1 when k == 0.
int pw(int x,int k,int mod){
    long long base=x,acc=1;
    while(k){
        if(k&1) acc=acc*base%mod;
        base=base*base%mod;
        k>>=1;
    }
    return (int)acc;
}
// Shared state for BZOJ 3992:
//   id      bit-reversal permutation for length n = 2^ln
//   an      remaining exponent, consumed bit by bit in main
//   m       the modulus read from input
//   a, ans  power base and accumulated result, indexed by discrete log
//   b       scratch copy used by mul
//   c       discrete-log table: c[g^i mod m] = i
//   v, ti   visited stamps used by check
//   ny      inverse of n modulo Mod
int id[N],an,n,m,ln,a[N],b[N],ans[N],c[N/2],v[N/2],ti,ny;
// In-place NTT of s[0..n-1] modulo Mod, with 3 as the generator.
// si = 1: forward transform; si = -1: inverse transform including the
// scaling by ny. Results are residues in (-Mod, Mod), not normalised.
void ntt(int *s,int si){
    for(int i=1;i<n;++i)
        if(i<id[i]) std::swap(s[i],s[id[i]]);
    for(int half=1;half<n;half<<=1){
        const int q=(Mod-1)/half/2;                  // exponent of a 2*half-th root
        const int root=pw(3,si==1?q:Mod-1-q,Mod);
        for(int base=0;base<n;base+=half<<1){
            int *lo=s+base,*hi=lo+half;
            int w=1;
            for(int k=0;k<half;++k){
                const int u=lo[k];
                const int t=1ll*hi[k]*w%Mod;
                lo[k]=(u+t)%Mod;
                hi[k]=(u-t)%Mod;
                w=1ll*w*root%Mod;
            }
        }
    }
    if(si<0)
        for(int i=0;i<n;++i) s[i]=1ll*s[i]*ny%Mod;
}
// a <- a * bb with exponents taken modulo m-1 (indices are discrete logs),
// assuming both inputs are supported on indices 0..m-2. bb is first copied
// into the global scratch b, so bb may be the same array as a (squaring).
void mul(int *a,int *bb){
    for(int i=0;i<n;++i) b[i]=bb[i];
    ntt(a,1);
    ntt(b,1);
    for(int i=0;i<n;++i) a[i]=1ll*a[i]*b[i]%Mod;
    ntt(a,-1);
    // fold coefficient i >= m-1 back onto index i-(m-1)
    const int period=m-1;
    for(int i=period;i<n;++i){
        a[i-period]=(a[i-period]+a[i])%Mod;
        a[i]=0;
    }
}
// Returns 1 iff the powers x^0, x^1, ..., x^(m-2) are pairwise distinct
// modulo m (for prime m: x is a primitive root). v[] stamped with the fresh
// value ti serves as the visited set, so nothing has to be cleared between
// calls. u*x is evaluated in int, so m must stay below about 46341, and v
// must have at least m entries.
bool check(int x,int m){
    int u=1; ++ti;
    for(int i=1;i<m;++i,u=u*x%m){
        if(v[u]==ti)return 0;
        v[u]=ti;
    }
    return 1;
}
// Smallest i in [2, m] accepted by check (the primitive root used for the
// discrete-log table). Returns 0 if no candidate passes, so the function never
// falls off the end (undefined behaviour for a non-void function).
int get(int m){
    for(int i=2;i<=m;++i) if(check(i,m)) return i;
    return 0;
}
int main(){
    an=read(),m=read(); int x=read(),s=read(),g=get(m);
    for(int i=1,w=g;i<m-1;++i,w=w*g%m) c[w]=i;
    for(int i=1;i<=s;++i){
        int x=read(); if(!x)continue;
        a[c[x]]=1;
    }
    n=1,ln=0; while(n<m+m) n<<=1,++ln;
    for(int i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<ln-1);
    ny=pw(n,Mod-2,Mod); ans[0]=1;
    for(;an;an>>=1){
        if(an&1) mul(ans,a);
        mul(a,a);
    }
    printf("%d\n",(ans[c[x]]+Mod)%Mod);
    return 0;
}

BZOJ3456 多项式求逆

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=262145,mo=1004535809;
char O[1<<14],*S=O,*T=O;
// gc: next byte of stdin (0..255) read through the 16 KiB buffer O, whose
// unread part is [S, T); evaluates to -1 once stdin is exhausted. The byte is
// widened through unsigned char so that 0xFF can never be confused with -1.
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:(unsigned char)*S++)
// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped (a '-' among them makes the result negative). Returns 0 when
// the input ends before any digit instead of spinning on end-of-input.
inline int read(){
    int x=0,f=1,ch=gc;   // int, so the -1 end-of-input marker is representable
    while(ch!=-1 && (ch<'0' || ch>'9')){if(ch=='-')f=-1; ch=gc;}
    while(ch>='0' && ch<='9'){x=x*10+(ch-'0'); ch=gc;}
    return x*f;
}
// x^k mod mo by binary exponentiation; returns 1 when k == 0.
int pw(int x,int k){
    long long res=1,cur=x;
    for(;k;k>>=1){
        if(k&1) res=res*cur%mo;
        cur=cur*cur%mo;
    }
    return (int)res;
}
// id: bit-reversal table (filled by pre); an: n from the input; ln: log2 of
// the size chosen in main; g: series to invert; f: its inverse; c: second
// factor of the final product; t: scratch for get_inv; jc / ny: factorials
// and inverse factorials.
int id[N],an,ln,c[N],g[N],f[N],t[N],jc[N],ny[N];
// In-place NTT of s[0..n-1] modulo mo with generator 3; id must already hold
// the bit-reversal table for length n (see pre).
// si = 1: forward transform; si = -1: inverse transform including the
// division by n.
// Inputs may be any residues in (-mo, mo). Each butterfly first lifts its
// operands into [0, mo), so for n >= 2 every output lies in [0, mo) and
// x - y + mo stays below 2*mo < INT_MAX. Without the lift, a large positive
// x and a negative y would overflow int and yield a wrong residue.
void ntt(int *s,int n,int si){
    for(int i=1;i<n;++i) if(i<id[i]) std::swap(s[i],s[id[i]]);
    for(int i=1;i<n;i<<=1){
        int wn=pw(3,si==1?(mo-1)/i/2:mo-1-(mo-1)/i/2);
        for(int j=0;j<n;j+=i<<1){
            int e=1,*b=s+j,*c=b+i;
            for(int k=0;k<i;++k,e=1ll*e*wn%mo){
                int x=b[k],y=1ll*c[k]*e%mo;
                if(x<0) x+=mo;
                if(y<0) y+=mo;
                b[k]=(x+y)%mo,c[k]=(x-y+mo)%mo;
            }
        }
    }
    if(si<0){
        const int inv_n=pw(n,mo-2);   // only the inverse transform needs n^{-1}
        for(int i=0;i<n;++i) s[i]=1ll*s[i]*inv_n%mo;
    }
}
// Fills id[0..n-1] with the ln-bit bit-reversal permutation used by ntt.
void pre(const int n,const int ln){
    for(int i=0;i<n;++i){
        const int top=(i&1)<<(ln-1);   // lowest bit of i moves to the top
        id[i]=(id[i>>1]>>1)|top;
    }
}
void get_inv(int *a,int *b,const int u,const int ln){
    if(u==1){
        b[0]=pw(a[0],mo-2); return;
    }
    get_inv(a,b,u>>1,ln-1);
    pre(u<<1,ln+1);
    for(int i=0;i<u;++i) t[i]=a[i],t[i+u]=0;
    ntt(t,u<<1,1),ntt(b,u<<1,1);
    for(int i=0;i<(u<<1);++i) t[i]=(2ll-(LL)t[i]*b[i]%mo)*b[i]%mo;
    ntt(t,u<<1,-1);
    for(int i=0;i<u;++i) b[i]=t[i],b[i+u]=0;
}
// BZOJ 3456. With G(x) = sum 2^{C(i,2)} x^i / i! and
// C(x) = sum_{i>=1} 2^{C(i,2)} x^i / (i-1)! = x G'(x), we have
// C/G = x (ln G)'. The coefficient of ln G at x^an (the count of connected
// labelled graphs on an vertices) is therefore (an-1)! * [x^an] (C * G^{-1}).
int main(){
    an=read();
    jc[0]=ny[0]=1;
    for(int i=1;i<=an;++i){
        jc[i]=1ll*jc[i-1]*i%mo;            // i!
        ny[i]=pw(jc[i],mo-2);              // 1 / i!
    }
    for(int i=0;i<=an;++i){
        const int edges=1ll*i*(i-1)/2%(mo-1);   // C(i,2), exponent reduced mod mo-1
        const int all=pw(2,edges);              // 2^{C(i,2)}
        g[i]=1ll*all*ny[i]%mo;
        if(i>=1) c[i]=1ll*all*ny[i-1]%mo;
    }
    int n=1;
    for(ln=0;n<=an;++ln) n<<=1;
    get_inv(g,f,n,ln);                    // f = 1/g mod x^n
    const int len=n<<1;
    ntt(c,len,1);
    ntt(f,len,1);
    for(int i=0;i<len;++i) f[i]=1ll*f[i]*c[i]%mo;
    ntt(f,len,-1);
    printf("%d\n",(1ll*f[an]*jc[an-1]%mo+mo)%mo);
    return 0;
}

51nod 1348：分治 NTT + 双模数 CRT 合并

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=131073,mo1=998244353,mo2=1004535809,mo=100003;
char O[1<<14],*S=O,*T=O;
// gc: next byte of stdin (0..255) read through the 16 KiB buffer O, whose
// unread part is [S, T); evaluates to -1 once stdin is exhausted. The byte is
// widened through unsigned char so that 0xFF can never be confused with -1.
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:(unsigned char)*S++)
// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped (a '-' among them makes the result negative). Returns 0 when
// the input ends before any digit instead of spinning on end-of-input.
inline int read(){
    int x=0,f=1,ch=gc;   // int, so the -1 end-of-input marker is representable
    while(ch!=-1 && (ch<'0' || ch>'9')){if(ch=='-')f=-1; ch=gc;}
    while(ch>='0' && ch<='9'){x=x*10+(ch-'0'); ch=gc;}
    return x*f;
}
// A, B: NTT scratch; a[d]: polynomial produced at recursion depth d (the
// depth must stay below 18); b1[d], b2[d]: copies of the left/right child
// results at depth d; id: bit-reversal table; an: number of values;
// o[1..an]: the values; mod: modulus currently used by ntt.
int A[N],B[N],a[18][N],b1[18][N],b2[18][N];
int id[N],an,o[N],mod;
// x^k mod p by iterative square-and-multiply; returns 1 when k == 0.
int pw(int x,int k,int p){
    long long acc=1,sq=x;
    for(;k;k>>=1){
        if(k&1) acc=acc*sq%p;
        sq=sq*sq%p;
    }
    return (int)acc;
}
// In-place NTT of s[0..n-1] modulo the global `mod`, using 3 as the generator
// for both moduli. id must hold the bit-reversal table for length n.
// si = 1: forward; si = -1: inverse including the division by n.
// Outputs are residues in (-mod, mod).
void ntt(int *s,int n,int si){
    for(int i=1;i<n;++i)
        if(i<id[i]) std::swap(s[i],s[id[i]]);
    for(int half=1;half<n;half<<=1){
        const int q=(mod-1)/2/half;
        const int root=pw(3,si==1?q:mod-1-q,mod);
        for(int base=0;base<n;base+=half<<1){
            int *lo=s+base,*hi=lo+half;
            int w=1;
            for(int k=0;k<half;++k){
                const int u=lo[k];
                const int t=1ll*hi[k]*w%mod;
                lo[k]=(u+t)%mod;
                hi[k]=(u-t)%mod;
                w=1ll*w*root%mod;
            }
        }
    }
    if(si<0){
        const int inv_n=pw(n,mod-2,mod);
        for(int i=0;i<n;++i) s[i]=1ll*s[i]*inv_n%mod;
    }
}
// (x * y) mod m, returned in [0, m), for any 64-bit x, y and any m >= 1.
// The product is formed exactly in 128-bit arithmetic (GCC/Clang __int128).
// This avoids the double-rounding mulmod trick, whose 64-bit x*y relies on
// signed overflow (undefined behaviour).
long long mul(long long x,long long y,long long m){
    __int128 r=(__int128)x*y%m;
    if(r<0) r+=m;
    return (long long)r;
}
// CRT: the value in [0, mo1*mo2) congruent to m1 (mod mo1) and to m2
// (mod mo2). m1 / m2 may be negative residues; mul reduces them.
LL merge(int m1,int m2){
    static const LL m=1ll*mo1*mo2;
    // mo2 * (mo2^{-1} mod mo1) and mo1 * (mo1^{-1} mod mo2) depend only on
    // the two moduli, so they are computed once rather than with two modular
    // exponentiations on every call.
    static const LL e1=(LL)mo2*pw(mo2,mo1-2,mo1);
    static const LL e2=(LL)mo1*pw(mo1,mo2-2,mo2);
    return (mul(e1,m1,m)+mul(e2,m2,m))%m;
}
void solve(int l,int r,int d){
    if(l==r){a[d][0]=1,a[d][1]=o[l]; return;}
    int mid=l+r>>1,m=r-l+1,ln=0,n=1;
    while(n<=m+m) n<<=1,++ln;
    solve(l,mid,d+1);
    for(int i=0;i<=mid-l+1;++i) b1[d][i]=a[d+1][i],a[d+1][i]=0;
    solve(mid+1,r,d+1);
    for(int i=0;i<=r-mid;++i) b2[d][i]=a[d+1][i],a[d+1][i]=0;
    for(int i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<ln-1);
    mod=mo1;
    for(int i=0;i<=mid-l+1;++i) A[i]=b1[d][i];
    for(int i=0;i<=r-mid;++i) B[i]=b2[d][i];
    ntt(A,n,1); ntt(B,n,1);
    for(int i=0;i<n;++i) A[i]=1ll*A[i]*B[i]%mod,B[i]=0;
    ntt(A,n,-1);
    for(int i=0;i<=m;++i) a[d][i]=A[i],A[i]=0;

    mod=mo2;
    for(int i=0;i<=mid-l+1;++i) A[i]=b1[d][i];
    for(int i=0;i<=r-mid;++i) B[i]=b2[d][i];
    ntt(A,n,1); ntt(B,n,1);
    for(int i=0;i<n;++i) A[i]=1ll*A[i]*B[i]%mod,B[i]=0;
    ntt(A,n,-1);
    for(int i=0;i<=m;++i) a[d][i]=merge(a[d][i],A[i])%mo,A[i]=0;
}
// 51nod 1348: reads an values o[1..an] and q queries. For each query k it
// prints the coefficient of x^k in prod (1 + o[i] x), i.e. the k-th
// elementary symmetric sum, modulo mo.
int main(){
    an=read(); int q=read();
    // Normalise inputs to [0, mo). solve's CRT step reconstructs each product
    // coefficient in [0, mo1*mo2), which is only the true value when that value
    // is non-negative. A negative one would come back shifted by mo1*mo2,
    // which is not a multiple of mo.
    for(int i=1;i<=an;++i) o[i]=(read()%mo+mo)%mo;
    if(an>0) solve(1,an,0);
    else a[0][0]=1;                  // empty product; solve(1,0,...) would recurse forever
    while(q--) printf("%d\n",(a[0][read()]+mo)%mo);
    return 0;
}

51nod 1172：任意模数 FFT（这里写的是拆系数的 MTT）

#include<bits/stdc++.h>
using namespace std;
typedef long double ld;
typedef long long LL;
const int N=131073,mo=1e9+7;
// pi at full long double precision. acos(-1) with an int argument would
// select the double overload and silently drop the extra precision of ld.
const ld pi=std::acos(-1.0L);
char O[1<<14],*S=O,*T=O;
// gc: next byte of stdin (0..255) read through the 16 KiB buffer O, whose
// unread part is [S, T); evaluates to -1 once stdin is exhausted. The byte is
// widened through unsigned char so that 0xFF can never be confused with -1.
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:(unsigned char)*S++)
// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped (a '-' among them makes the result negative). Returns 0 when
// the input ends before any digit instead of spinning on end-of-input.
inline int read(){
    int x=0,f=1,ch=gc;   // int, so the -1 end-of-input marker is representable
    while(ch!=-1 && (ch<'0' || ch>'9')){if(ch=='-')f=-1; ch=gc;}
    while(ch>='0' && ch<='9'){x=x*10+(ch-'0'); ch=gc;}
    return x*f;
}
struct P{
    ld x,y;
    P(ld a=0,ld b=0){x=a,y=b;}
    inline P con(){return P(x,-y);}
}A[N],B[N],dfa[N],dfb[N],dfc[N],dfd[N];
P operator+(P x,P y){return P(x.x+y.x,x.y+y.y);}
P operator-(P x,P y){return P(x.x-y.x,x.y-y.y);}
P operator*(P x,P y){return P(x.x*y.x-x.y*y.y,x.y*y.x+x.x*y.y);}
// id: bit-reversal table for length n = 2^ln; an: sequence length;
// ny, c: inverse factorials and the coefficients built in main;
// a: input sequence; jc is not used.
int id[N],an,n,ln,ny[N>>1],jc[N>>1],a[N],c[N];
// In-place radix-2 FFT of s[0..n-1] in long double precision.
// si = 1: forward; si = -1: inverse WITHOUT the 1/n scaling. mul divides
// by n itself when it rounds the results.
inline void fft(P *s,int si){
    for(int i=1;i<n;++i)
        if(i<id[i]) std::swap(s[i],s[id[i]]);
    for(int half=1;half<n;half<<=1){
        const P step(cos(pi/half),si*sin(pi/half));   // primitive 2*half-th root
        for(int base=0;base<n;base+=half<<1){
            P *lo=s+base,*hi=lo+half;
            P w(1,0);
            for(int k=0;k<half;++k){
                const P u=lo[k],v=hi[k]*w;
                lo[k]=u+v;
                hi[k]=u-v;
                w=w*step;
            }
        }
    }
}
// Prints the first an coefficients of x*y modulo mo, one per line. Assumes
// x[0..an-1] and y[0..an-1] lie in [0, 2^30). Each value is split as
// hi*2^15 + lo, and lo/hi are packed into the real/imaginary parts of one
// complex sequence. The four partial products lo*lo, lo*hi, hi*lo and hi*hi
// (all with 15-bit factors) then come out of 2 forward and 2 inverse
// transforms.
inline void mul(int *x,int *y){
    for(int i=0;i<an;++i)
        A[i]=P(x[i]&32767,x[i]>>15),B[i]=P(y[i]&32767,y[i]>>15);
    fft(A,1); fft(B,1);
    for(int i=0;i<n;++i){
        int j=(n-i)&(n-1);            // index of the mirrored frequency
        // Separate the two packed spectra:
        //   FFT(lo)[i] = (A[i] + conj(A[j])) / 2
        //   FFT(hi)[i] = (A[i] - conj(A[j])) / (2i)
        P p=(A[i]+A[j].con())*P(0.5,0),q=(A[i]-A[j].con())*P(0,-0.5);
        P r=(B[i]+B[j].con())*P(0.5,0),s=(B[i]-B[j].con())*P(0,-0.5);
        dfa[i]=p*r,dfb[i]=p*s,dfc[i]=q*r,dfd[i]=q*s;
    }
    // Re-pack two real products per complex sequence:
    // A = lo*lo + i lo*hi, B = hi*lo + i hi*hi.
    for(int i=0;i<n;++i)
        A[i]=dfa[i]+dfb[i]*P(0,1),B[i]=dfc[i]+dfd[i]*P(0,1);
    fft(A,-1); fft(B,-1);
    for(int i=0;i<an;++i){
        // fft's inverse is unscaled, hence the division by n before rounding
        int p=(LL)(A[i].x/n+0.5)%mo,q=(LL)(A[i].y/n+0.5)%mo,r=(LL)(B[i].x/n+0.5)%mo,s=(LL)(B[i].y/n+0.5)%mo;
        // The recombined value is a long long below mo. Cast it to int to match
        // %d, since passing a long long for %d is undefined behaviour.
        printf("%d\n",(int)((((LL)s<<30)+((LL)(q+r)<<15)+p)%mo));
    }
}
int main(){
    an=read(); int k=read(),i;
    ny[0]=ny[1]=c[0]=1,c[1]=k;
    for(i=2;i<an;++i) ny[i]=1ll*(mo-mo/i)*ny[mo%i]%mo,c[i]=1ll*c[i-1]*(k+i-1)%mo;
    for(i=2;i<an;++i) ny[i]=1ll*ny[i]*ny[i-1]%mo,c[i]=1ll*c[i]*ny[i]%mo;
    for(i=0;i<an;++i) a[i]=read();
    n=1,ln=0; while(n<=an+an) n<<=1,++ln;
    for(i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<ln-1);
    mul(a,c);
    return 0;
}

BZOJ 3625：多项式开根（配合多项式求逆）

#include<bits/stdc++.h>
using namespace std;
const int N=262145,mo=998244353,n2=499122177;   // n2 = 2^{-1} mod mo
char O[1<<14],*S=O,*T=O;
// gc: next byte of stdin (0..255) read through the 16 KiB buffer O, whose
// unread part is [S, T); evaluates to -1 once stdin is exhausted. The byte is
// widened through unsigned char so that 0xFF can never be confused with -1.
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:(unsigned char)*S++)
// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped (a '-' among them makes the result negative). Returns 0 when
// the input ends before any digit instead of spinning on end-of-input.
inline int read(){
    int x=0,f=1,ch=gc;   // int, so the -1 end-of-input marker is representable
    while(ch!=-1 && (ch<'0' || ch>'9')){if(ch=='-')f=-1; ch=gc;}
    while(ch>='0' && ch<='9'){x=x*10+(ch-'0'); ch=gc;}
    return x*f;
}
// id: bit-reversal table; c, d: scratch for inv / Sqrt; a: input series and
// later the final result; b: square root series.
int id[N],c[N],d[N],a[N],b[N];
// x^k mod mo by binary exponentiation; returns 1 when k == 0.
int pw(int x,int k){
    long long acc=1,cur=x;
    while(k){
        if(k&1) acc=acc*cur%mo;
        cur=cur*cur%mo;
        k>>=1;
    }
    return (int)acc;
}
// In-place NTT of s[0..n-1] modulo mo with generator 3; id must hold the
// bit-reversal table for length n. si = 1: forward; si = -1: inverse
// including the division by n. Outputs are residues in (-mo, mo).
void ntt(int *s,int n,int si){
    for(int i=1;i<n;++i)
        if(i<id[i]) std::swap(s[i],s[id[i]]);
    for(int half=1;half<n;half<<=1){
        const int q=(mo-1)/2/half;
        const int root=pw(3,si==1?q:mo-1-q);
        for(int base=0;base<n;base+=half<<1){
            int *lo=s+base,*hi=lo+half;
            int w=1;
            for(int k=0;k<half;++k){
                const int u=lo[k];
                const int t=1ll*hi[k]*w%mo;
                lo[k]=(u+t)%mo;
                hi[k]=(u-t)%mo;
                w=1ll*w*root%mo;
            }
        }
    }
    if(si<0){
        const int inv_n=pw(n,mo-2);
        for(int i=0;i<n;++i) s[i]=1ll*s[i]*inv_n%mo;
    }
}
// Power-series inverse by Newton iteration. On return b[0..n-1] = a^{-1}
// mod x^n and b[n..2n-1] = 0. Requires n == 2^ln and b[1..2n-1] zero on
// entry. Uses the global scratch c and leaves id set for length 2n.
void inv(int *a,int *b,int n,int ln){
    if(n==1){
        b[0]=pw(a[0],mo-2);
        return;
    }
    inv(a,b,n>>1,ln-1);                 // b = a^{-1} mod x^{n/2}
    const int len=n<<1;
    for(int i=0;i<n;++i){
        c[i]=a[i];
        c[i+n]=0;
    }
    for(int i=0;i<len;++i) id[i]=(id[i>>1]>>1)|((i&1)<<ln);
    ntt(c,len,1);
    ntt(b,len,1);
    // b <- b * (2 - a*b), pointwise
    for(int i=0;i<len;++i){
        const int ab=1ll*c[i]*b[i]%mo;
        b[i]=1ll*b[i]*(2-ab)%mo;
    }
    ntt(b,len,-1);
    memset(b+n,0,n*sizeof(int));
}
// Power-series square root by Newton iteration, assuming a[0] == 1 (the base
// case fixes b[0] = 1). On return b[0..n-1] = sqrt(a) mod x^n and
// b[n..2n-1] = 0. Requires n == 2^ln and b[1..2n-1] zero on entry. Uses the
// globals c and d as scratch; through inv it also overwrites id.
void Sqrt(int *a,int *b,int n,int ln){
    if(n==1){
        b[0]=1;
        return;
    }
    Sqrt(a,b,n>>1,ln-1);                // b = sqrt(a) mod x^{n/2}
    const int len=n<<1;
    memset(d,0,len*sizeof(int));
    inv(b,d,n,ln);                       // d = 1/b mod x^n; id now fits length len
    for(int i=0;i<n;++i){
        c[i]=a[i];
        c[i+n]=0;
    }
    ntt(c,len,1);
    ntt(b,len,1);
    ntt(d,len,1);
    // b <- (a/b + b) / 2, pointwise (n2 is the inverse of 2)
    for(int i=0;i<len;++i){
        const long long sum=1ll*c[i]*d[i]%mo+b[i];
        b[i]=sum%mo*n2%mo;
    }
    ntt(b,len,-1);
    memset(b+n,0,n*sizeof(int));
}
// BZOJ 3625: with C(x) = sum of x^w over the given weights w <= m, computes
// F = 2 / (1 + sqrt(1 - 4C)) modulo x^(m+1) and prints F[1..m].
// a first holds 1 - 4C, b = sqrt(a), then a = 1 / (1 + b).
int main(){
    int n=read(),m=read(); a[0]=1;
    for(int i=1;i<=n;++i){
        int x=read(); if(x<=m)a[x]=mo-4;      // -4 mod mo
    }
    int ln=0; n=1; for(;n<=m;++ln,n<<=1);     // n = 2^ln > m
    Sqrt(a,b,n,ln); b[0]=(b[0]+1)%mo;
    memset(a,0,n*sizeof(int)); inv(b,a,n,ln);
    // Double a[i] with a 64-bit multiply. a[i] may be a negative residue, and
    // left-shifting a negative int is undefined behaviour before C++20.
    for(int i=1;i<=m;++i) printf("%d\n",(int)((2ll*a[i]%mo+mo)%mo));
    return 0;
}

猜你喜欢

转载自blog.csdn.net/cgh_Andy/article/details/79462760
今日推荐