版权声明:本文为博主原创文章,未经博主允许必须转载。 https://blog.csdn.net/qq_35950004/article/details/85274386
1.核心:
FFT:
正常版本:
#include<bits/stdc++.h>
#define maxn 400005
using namespace std;
const double PI = acos(-1);
// Minimal complex number: real part r, imaginary part i.
struct cplx
{
    double r, i;
    cplx(double r = 0, double i = 0) : r(r), i(i) {}
    cplx operator +(const cplx &o) const
    {
        return cplx(r + o.r, i + o.i);
    }
    cplx operator -(const cplx &o) const
    {
        return cplx(r - o.r, i - o.i);
    }
    // Complex product: real = r*o.r - i*o.i, imag = i*o.r + r*o.i.
    cplx operator *(const cplx &o) const
    {
        return cplx(r * o.r - i * o.i, i * o.r + r * o.i);
    }
    // Complex conjugate.
    cplx conj() { return cplx(r, -i); }
};
// a: packed input (coefficients of A in .r, of B in .i); b: product spectrum.
cplx a[maxn], b[maxn];
// Bit-reversal permutation table (written by FFT).
int r[maxn] = {};
// Twiddle scratch: w[0] = 1 is fixed, w[1..] is rebuilt for each FFT layer.
cplx w[maxn] = {1};
// In-place iterative radix-2 FFT on A[0 .. 2^lgn - 1].
// tp = 1 computes the forward transform, tp = -1 the inverse. The inverse only
// divides the real parts by n (main reads only .r of the result).
// Each layer's twiddles are produced by repeated multiplication with the
// layer's primitive root, so rounding error accumulates within a layer.
inline void FFT(cplx *A,int lgn,int tp)
{
    const int n = 1 << lgn;
    // Build the bit-reversal table (r[0] stays 0 from static init) and apply
    // the permutation; pairs (i, r[i]) are disjoint, so fusing is safe.
    for (int i = 1; i < n; ++i)
    {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lgn - 1));
        if (i < r[i]) swap(A[i], A[r[i]]);
    }
    // Butterfly layers: blocks of length 2*half.
    for (int half = 1; half < n; half <<= 1)
    {
        const cplx wn(cos(PI / half), sin(PI / half) * tp);
        for (int k = 1; k < half; ++k) w[k] = w[k - 1] * wn;
        for (int st = 0; st < n; st += half << 1)
        {
            cplx *lo = A + st;
            cplx *hi = A + st + half;
            for (int k = 0; k < half; ++k)
            {
                const cplx t = w[k] * hi[k];
                hi[k] = lo[k] - t;
                lo[k] = lo[k] + t;
            }
        }
    }
    if (tp == -1)
        for (int i = 0; i < n; ++i) A[i].r /= n;
}
// Reads n, m and the coefficients of A (degree n) and B (degree m).
// A is packed into the real parts and B into the imaginary parts.
// The product is then computed with one forward and one inverse FFT.
// Prints the n+m+1 product coefficients, each followed by a space.
int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++) scanf("%lf",&a[i].r);
    for(int i=0;i<=m;i++) scanf("%lf",&a[i].i);
    n++,m++; // n, m are now coefficient counts
    int len = 0;
    for(;n+m>(1<<len);len++); // 2^len >= n+m covers the n+m-1 product coefficients
    FFT(a,len,1);
    for(int i=0,ci,Len = 1<<len;i<Len;i++)
    {
        ci = (Len - i) & (Len - 1); // index of -i modulo Len
        // With P = DFT(A + iB):
        //   DFT(A)[i] = (P[i] + conj(P[-i])) / 2
        //   DFT(B)[i] = (P[i] - conj(P[-i])) / (2i)
        cplx A = (a[i] + a[ci].conj())*cplx(0.5,0) , B = (a[i] - a[ci].conj())*cplx(0,-0.5);
        b[i] = A * B;
    }
    FFT(b,len,-1);
    // Round to nearest. The former int(x + 0.5) truncates toward zero,
    // which returns wrong values for negative coefficients (e.g. -3.0 -> -2).
    for(int i=0;i<n+m-1;i++) printf("%lld ",llround(b[i].r));
}
预处理单位根(精度高):
#include<bits/stdc++.h>
#define maxn 300005
using namespace std;
// pi spelled out beyond double precision; the literal rounds to the nearest double.
const double Pi = 3.1415926535897932384626433832795;
// Complex number with real part r and imaginary part i.
struct cplx
{
    double r, i;
    cplx(double r = 0, double i = 0) : r(r), i(i) {}
    cplx operator +(const cplx &rhs) const { return cplx(r + rhs.r, i + rhs.i); }
    cplx operator -(const cplx &rhs) const { return cplx(r - rhs.r, i - rhs.i); }
    cplx operator *(const cplx &rhs) const
    {
        return cplx(r * rhs.r - i * rhs.i, i * rhs.r + r * rhs.i);
    }
    // Complex conjugate.
    cplx conj() const { return cplx(r, -i); }
};
// w: unit-root table rebuilt by FFT; A: packed input/spectrum; B: product spectrum.
cplx w[maxn], A[maxn], B[maxn];
// Bit-reversal permutation table.
int r[maxn];
// In-place radix-2 FFT of A[0 .. 2^lgn - 1] with a precomputed unit-root table.
// tp = 1: forward transform. tp = -1: inverse transform, with both parts scaled by 1/n.
// w[j] = exp(i*pi*j/n) for j < n is rebuilt on every call. A layer of half-length l
// uses w[k*(n/l)] = exp(i*pi*k/l), conjugated for the inverse transform.
inline void FFT(cplx A[maxn],int lgn,int tp)
{
    int n = 1<<lgn;
    for(int i=0;i<n;i++) w[i]=cplx(cos(i*Pi/n),sin(i*Pi/n));
    // Bit-reversal table. r[0] is always 0; start the recurrence at i = 1 so that
    // a size-1 transform (lgn == 0) never evaluates a shift by lgn-1 = -1 (UB).
    r[0] = 0;
    for(int i=1;i<n;i++) r[i] = (r[i>>1]>>1)|((i&1)<<(lgn-1));
    for(int i=1;i<n;i++) if(i < r[i]) swap(A[i] , A[r[i]]);
    for(int L=2;L<=n;L<<=1)
        for(int st=0,l=L>>1;st<n;st+=L)
            for(int k=0,lc=0,inc=n/l;k<l;k++,lc+=inc)
            {
                cplx tmp = (tp==1 ? w[lc] : w[lc].conj()) * A[st+k+l];
                A[st+k+l]=A[st+k]-tmp,A[st+k]=A[st+k]+tmp;
            }
    if(tp==-1) for(int i=0;i<n;i++) A[i].r/=n,A[i].i/=n;
}
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&A[i].r);
for(int i=0;i<=m;i++) scanf("%lf",&A[i].i);
int lgn=0;for(;n+m>=(1<<lgn);lgn++);
FFT(A,lgn,1);
for(int i=0,len=1<<lgn;i<len;i++)
{
cplx u=A[i],v=A[(len-1)&(len-i)].conj();
B[i]=(u+v)*(u-v)*cplx(0,-0.25);
}
FFT(B,lgn,-1);
for(int i=0;i<n+m;i++) printf("%d ",int(round(B[i].r)));
printf("%d\n",(int)round(B[n+m].r));
}
MTT
合并DFT详见myy论文。
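合并IDFT其实不需要任何技巧,因为 IDFT 是线性的:$\mathrm{IDFT}\big(\mathrm{DFT}(A)+i\cdot\mathrm{DFT}(B)\big)=A+iB$。当 $A,B$ 都是实系数时,结果的实部就是 $A$,虚部就是 $B$。
如果觉得慢的话可以将long double改为double,
然后预处理单位根,照样可以满足
本题(模数 $p=10^9+7$)的精度要求。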
#include<bits/stdc++.h>
#define maxn 300005
#define LL long long
#define M ((1<<15)-1)
#define ld long double
using namespace std;
// 32 KiB stdin buffer; [cs, ct) is the part of cb not yet consumed.
char cb[1<<15],*cs=cb,*ct=cb;
// Buffered byte reader: refills cb with fread when empty; yields 0 once stdin is exhausted.
#define getc() (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<15,stdin),cs==ct)?0:*cs++)
// Reads the next non-negative decimal integer into res, skipping non-digit bytes
// before it (a leading '-' is skipped, not honoured). Consumes one byte after the number.
// If input ends before any digit appears, res is set to 0. The previous version
// looped forever here, because getc() keeps returning 0 at EOF.
inline void read(int &res)
{
    char ch = getc();
    // Cast to unsigned char: isdigit on a negative char value is undefined behaviour.
    while(ch && !isdigit((unsigned char)ch)) ch = getc();
    res = 0;
    for(; isdigit((unsigned char)ch); ch = getc()) res = res*10 + (ch-'0');
}
// Modulus for the product coefficients; set in main.
int p;
// pi as a long double literal. Without the L suffix the literal is a double, so
// Pi would be rounded to double precision before being widened, wasting the
// extra accuracy the long double twiddle factors are meant to provide.
const ld Pi = 3.1415926535897932384626433832795L;
// Complex number in long double precision (real part r, imaginary part i).
struct cplx
{
    ld r,i;
    cplx(ld r=0,ld i=0):r(r),i(i){}
    cplx operator +(const cplx &B)const{ return cplx(r+B.r,i+B.i); }
    cplx operator -(const cplx &B)const{ return cplx(r-B.r,i-B.i); }
    cplx operator *(const cplx &B)const{ return cplx(r*B.r-i*B.i,i*B.r+B.i*r); }
    cplx conj(){ return cplx(r,-i); }
}w[maxn]={1}; // w[0] = 1 is fixed; FFT rewrites w[1..] for every layer
// a, b: input coefficients; c: product mod p; r: bit-reversal table.
int a[maxn],b[maxn],c[maxn],r[maxn];
// In-place radix-2 FFT on A[0 .. 2^lgn - 1]. tp = 1 is the forward transform;
// tp = -1 is the inverse, which also divides both parts by n.
// Per layer: w[1] = exp(tp*i*pi/l) and w[k] = w[k-1]*w[1] (w[0] stays 1).
inline void FFT(cplx A[maxn],int lgn,int tp)
{
    const int n = 1 << lgn;
    // Bit-reversal table (r[0] is 0 from static init) fused with the swaps;
    // the swapped pairs are disjoint, so doing both in one pass is equivalent.
    for (int i = 1; i < n; ++i)
    {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lgn - 1));
        if (i < r[i]) swap(A[i], A[r[i]]);
    }
    for (int half = 1; half < n; half <<= 1)
    {
        w[1] = cplx(cos(Pi / half), sin(Pi / half) * tp);
        for (int k = 2; k < half; ++k) w[k] = w[k - 1] * w[1];
        for (int st = 0; st < n; st += 2 * half)
            for (int k = 0; k < half; ++k)
            {
                cplx &x = A[st + k];
                cplx &y = A[st + k + half];
                const cplx t = w[k] * y;
                y = x - t;
                x = x + t;
            }
    }
    if (tp == -1)
        for (int i = 0; i < n; ++i)
        {
            A[i].r /= n;
            A[i].i /= n;
        }
}
// Scratch spectra: s[0]/s[1] hold the packed high/low halves, s[2]/s[3] the packed products.
cplx s[4][maxn];
// c = a * b (cyclic convolution of length 2^lgn) with coefficients reduced mod p.
// Each coefficient x is split as x = (x>>15)*2^15 + (x&M). The four partial
// products hi*hi, lo*lo, hi*lo and lo*hi come out of two forward and two inverse
// FFTs: two half-spectra are packed per complex transform.
inline void mul(int a[maxn],int b[maxn],int lgn,int c[maxn])
{
    const int n = 1 << lgn;
    for (int i = 0; i < n; ++i)
    {
        s[0][i] = cplx(a[i] >> 15, b[i] >> 15); // high halves: a in .r, b in .i
        s[1][i] = cplx(a[i] & M, b[i] & M);     // low halves
    }
    FFT(s[0], lgn, 1);
    FFT(s[1], lgn, 1);
    for (int i = 0; i < n; ++i)
    {
        const int j = (n - i) & (n - 1); // index of -i modulo n
        cplx u0 = s[0][i], v0 = s[0][j].conj();
        cplx u1 = s[1][i], v1 = s[1][j].conj();
        // Unpack: real-part spectrum (u+v)/2, imaginary-part spectrum (u-v)/(2i).
        cplx ah = (u0 + v0) * cplx(0.5, 0), bh = (u0 - v0) * cplx(0, -0.5);
        cplx al = (u1 + v1) * cplx(0.5, 0), bl = (u1 - v1) * cplx(0, -0.5);
        // Re-pack two real products per transform; IDFT is linear.
        s[2][i] = ah * bh + cplx(0, 1) * (al * bl);
        s[3][i] = ah * bl + cplx(0, 1) * bh * al;
    }
    FFT(s[2], lgn, -1);
    FFT(s[3], lgn, -1);
    for (int i = 0; i < n; ++i)
    {
        const LL hh = llround(s[2][i].r) % p; // hi(a) * hi(b)
        const LL ll = llround(s[2][i].i) % p; // lo(a) * lo(b)
        const LL hl = llround(s[3][i].r) % p; // hi(a) * lo(b)
        const LL lh = llround(s[3][i].i) % p; // hi(b) * lo(a)
        c[i] = (ll + (((hl + lh) % p) << 15) + (hh << 30)) % p;
    }
}
// Reads n, m and the n+1 / m+1 coefficients of two polynomials. Multiplies
// them modulo p = 1000000007 and prints the n+m+1 product coefficients.
int main()
{
    int n, m;
    p = 1000000007;
    read(n);
    read(m);
    for (int i = 0; i <= n; ++i) read(a[i]);
    for (int i = 0; i <= m; ++i) read(b[i]);
    // Transform size: smallest power of two strictly greater than n+m.
    int lgn = 0;
    while ((1 << lgn) <= n + m) ++lgn;
    mul(a, b, lgn, c);
    for (int i = 0; i < n + m; ++i) printf("%d ", (c[i] + p) % p);
    printf("%d\n", (c[n + m] + p) % p);
}
2.多项式求逆:
MTT版本:
#include<bits/stdc++.h>
#define maxn 300005
#define mod 1000000007
#define LL long long
#define M ((1<<15)-1)
#define ld long double
using namespace std;
// pi as a long double literal. An unsuffixed literal is a double and would be
// rounded to double precision first, so the unit-root table w would lose the
// extra accuracy that long double is used for here.
const ld Pi = 3.1415926535897932384626433832795L;
// Size of the unit-root table w (set in main); FFT indexes it with stride wlen/l.
int wlen=0;
// Complex number in long double precision (real part r, imaginary part i).
struct cplx
{
    ld r,i;
    cplx(ld r=0,ld i=0):r(r),i(i){}
    cplx operator +(const cplx &B)const{return cplx(r+B.r,i+B.i);}
    cplx operator -(const cplx &B)const{return cplx(r-B.r,i-B.i);}
    cplx operator *(const cplx &B)const{return cplx(r*B.r-i*B.i,i*B.r+r*B.i);}
    cplx conj(){return cplx(r,-i);}
}w[maxn]; // w[j] = exp(i*pi*j/wlen) for j < wlen, filled in main
// Binary exponentiation: returns base^k reduced mod `mod`.
// Inv calls it with k = mod-2 to invert a[0] (Fermat's little theorem).
inline int Pow(int base,int k)
{
    int ret = 1;
    while (k)
    {
        if (k & 1) ret = 1ll * ret * base % mod;
        base = 1ll * base * base % mod;
        k >>= 1;
    }
    return ret;
}
// Bit-reversal permutation table.
int r[maxn];
// In-place radix-2 FFT on A[0 .. 2^lgn - 1] using the shared unit-root table
// w[j] = exp(i*pi*j/wlen). A layer of half-length l reads w[k*(wlen/l)],
// conjugated unless tp == 1. tp = -1 also scales both parts by 1/n.
// NOTE(review): wlen/l must be exact; main sizes wlen to Inv's largest transform.
inline void FFT(cplx A[maxn],int lgn,int tp)
{
    const int n = 1 << lgn;
    // Bit-reversal table fused with the swaps (the swapped pairs are disjoint).
    for (int i = 0; i < n; ++i)
    {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lgn - 1));
        if (i < r[i]) swap(A[i], A[r[i]]);
    }
    for (int half = 1; half < n; half <<= 1)
    {
        const int inc = wlen / half; // stride through w for this layer
        for (int st = 0; st < n; st += half << 1)
            for (int k = 0; k < half; ++k)
            {
                cplx tw = w[k * inc];
                if (tp != 1) tw = tw.conj();
                const cplx t = tw * A[st + half + k];
                A[st + half + k] = A[st + k] - t;
                A[st + k] = A[st + k] + t;
            }
    }
    if (tp == -1)
        for (int i = 0; i < n; ++i)
        {
            A[i].r /= n;
            A[i].i /= n;
        }
}
// Scratch spectra: s[0]/s[1] packed high/low halves, s[2]/s[3] packed products.
cplx s[4][maxn];
// c = a * b (cyclic convolution of length 2^lgn) with coefficients reduced mod `mod`.
// Uses the 15-bit split arbitrary-modulus FFT: x = hi*2^15 + lo.
// Inputs may be negative (Inv passes 2 - a*b here), so each one is first reduced
// into [0, mod). Only non-negative values are then split and shifted:
// right-shifting a negative int is implementation-defined, and left-shifting a
// negative long long is undefined behaviour before C++20.
// Output coefficients are in [0, mod).
inline void mul(int a[maxn],int b[maxn],int lgn,int c[maxn])
{
    int n = 1<<lgn;
    for(int i=0;i<n;i++)
    {
        // x % mod + mod < 2*mod < INT_MAX, so this stays within int.
        int x = (a[i]%mod+mod)%mod, y = (b[i]%mod+mod)%mod;
        s[0][i]=cplx(x>>15,y>>15),s[1][i]=cplx(x&M,y&M);
    }
    FFT(s[0],lgn,1),FFT(s[1],lgn,1);
    for(int i=0;i<n;i++)
    {
        // f: the packed spectra at i and conj at -i; g: the four separated spectra.
        cplx f[4]={s[0][i],s[0][(n-i)&(n-1)].conj(),s[1][i],s[1][(n-1)&(n-i)].conj()};
        cplx g[4]={(f[0]+f[1])*cplx(0.5,0),(f[0]-f[1])*cplx(0,-0.5),
                   (f[2]+f[3])*cplx(0.5,0),(f[2]-f[3])*cplx(0,-0.5)};
        s[2][i] = g[0]*g[1]+cplx(0,1)*g[2]*g[3],s[3][i]=g[0]*g[3]+cplx(0,1)*g[1]*g[2];
    }
    FFT(s[2],lgn,-1),FFT(s[3],lgn,-1);
    for(int i=0;i<n;i++)
    {
        // hi*hi, lo*lo, hi(a)*lo(b), hi(b)*lo(a), each reduced mod `mod`.
        LL t[4]={llround(s[2][i].r)%mod,llround(s[2][i].i)%mod,llround(s[3][i].r)%mod,llround(s[3][i].i)%mod};
        // Multiply rather than shift so that a stray negative term stays well-defined.
        LL v = (t[0]*(1LL<<30)+t[1]+(t[2]+t[3])*(1LL<<15)) % mod;
        c[i] = (int)((v + mod) % mod);
    }
}
// Computes b = a^{-1} mod x^{2^lgn}, with coefficients mod `mod`, by Newton
// iteration. It starts from b = a[0]^{-1}, and each round doubles the number
// of correct coefficients via b <- b * (2 - a*b), truncated to 2^k terms.
// NOTE(review): a[0] must be non-zero mod `mod` for the inverse to exist.
void Inv(int a[maxn],int lgn,int b[maxn])
{
    b[0] = Pow(a[0], mod - 2);
    static int tmp[3][maxn];
    for (int k = 1; k <= lgn; ++k)
    {
        const int n = 1 << (k + 1); // transform size: products below never wrap
        const int have = n >> 2;    // coefficients of b known so far (2^(k-1))
        const int want = n >> 1;    // coefficients valid after this round (2^k)
        for (int i = 0; i < n; ++i)
        {
            tmp[0][i] = i < want ? a[i] : 0;
            tmp[1][i] = i < have ? b[i] : 0;
        }
        mul(tmp[0], tmp[1], k + 1, tmp[2]);                            // a*b
        for (int i = 0; i < n; ++i) tmp[2][i] = (i == 0) * 2 - tmp[2][i]; // 2 - a*b
        mul(tmp[2], tmp[1], k + 1, b);                                 // b*(2 - a*b)
        for (int i = want; i < n; ++i) b[i] = 0;
    }
}
int n,a[maxn],b[maxn];
int main()
{
int n;
scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%d",&a[i]);
int lgn = 0;
for(;n>(1<<lgn);lgn++);
wlen = 1<<(lgn+1);
for(int i=0;i<wlen;i++) w[i] = cplx(cos(i*Pi/wlen),sin(i*Pi/wlen));
Inv(a,lgn,b);
printf("%d",(b[0]+mod)%mod);
for(int i=1;i<n;i++) printf(" %d",(b[i]+mod)%mod);
}
多项式除法:
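求 $Q(x)$ 和 $R(x)$,使得 $F(x)=Q(x)\,G(x)+R(x)$,
其中 $\deg F=n$,$\deg G=m$,$\deg Q=n-m$,$\deg R<m$,
$F(x)$ 和 $G(x)$ 已知。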
不知道怎么理解。
多项式除法
还是手推一下式子比较好:
常系数线性递推
这篇看起来又线代又通俗
多点求值和快速插值
51nod 1387