最近重新学了下卷积,简单总结一下,不涉及细节内容:
1、FFT
朴素求法:$Coefficient-O(n^2)-CoefficientResult$
FFT:$Coefficient-O(nlogn)-Dot-O(n)-DotResult-O(nlogn)-CoefficientResult$
其中系数到点值的转化称为$DFT(离散傅里叶变换)$,而点值到系数的转为称为$IDFT(傅里叶逆变换)$
原本朴素的直接带入$n$个值的$DFT$和直接使用拉格朗日插值公式的$IDFT$的复杂度仍为$O(n^2)$
但$FFT$通过带入特定的值:单位根,使得两者都能迭代/分治得解决,将复杂度降到了$O(nlogn)$
优化的技巧和注意事项:
1、预处理$w[i]$
2、求出最终数组从后往前迭代省去递归常数
3、数组长度要先扩成2的倍数用于分治
模板:
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=3e6+10; struct Complex { db x,y; Complex(db a=0,db b=0){x=a;y=b;} Complex operator + (const Complex& rhs) {return Complex(x+rhs.x,y+rhs.y);} Complex operator - (const Complex& rhs) {return Complex(x-rhs.x,y-rhs.y);} Complex operator * (const Complex& rhs) {return Complex(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);} }a[MAXN],b[MAXN]; int n,m,lmt=1,dgt,par[MAXN]; void FFT(Complex *a,int flag) { for(int i=0;i<lmt;i++) if(i<par[i]) swap(a[i],a[par[i]]); for(int len=1;len<lmt;len<<=1) { Complex unit(cos(M_PI/len),flag*sin(M_PI/len)); for(int st=0;st<lmt;st+=(len<<1)) { Complex w(1,0); for(int k=st;k<st+len;k++,w=w*unit) { Complex A=a[k],B=w*a[k+len]; a[k]=A+B;a[k+len]=A-B; } } } if(flag==-1) for(int i=0;i<=n+m;i++) a[i].x=floor(a[i].x/lmt+0.5); } int main() { scanf("%d%d",&n,&m); for(int i=0;i<=n;i++) scanf("%lf",&a[i].x); for(int i=0;i<=m;i++) scanf("%lf",&b[i].x); while(lmt<=n+m) lmt<<=1,dgt++; for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1)); FFT(a,1);FFT(b,1); for(int i=0;i<lmt;i++) a[i]=a[i]*b[i]; FFT(a,-1); for(int i=0;i<=n+m;i++) printf("%d ",(int)a[i].x); return 0; }
2、NTT
单位根由于涉及了复数的运算,导致对精度要求高时会出错
而$NTT$就能使得整个$FFT$都能在模意义下计算,从而满足精度要求
考虑$FFT$引入单位根$w_n^k$是为了其什么性质来分治计算:
1、$w_n^k$互不相同,保证点值表示的合法
2、$w_{t*n}^{t*k}=w_n^k$且$w_n^{k+2/n}=w_n^k$,使得计算可分治
3、$\sum_{i=0}^{n-1} {w_n^k}^i=n*[k==0]$,保证逆矩阵构造的正确性
在模意义下引入质数$p=kn+1$,其原根$g$满足$g_t(t\in [0,p-1])$互不相同
这样令$p$的$k$次单位根为$g^{\frac{p-1}{k}}$,易证上述$w_n_k$的性质其在模意义下均满足
接下来考虑该怎样选择质数$p$
为了能够分治时允许$k$每次乘2,$p-1$的质因数分解中要有很多的2
令$p=r*2^k+1$,其能处理的数据规模为$[0,2^k]$,常用质数有:传送门
这样,我们就在模意义下利用原根的性质找到了可做$FFT$的“单位根”
由于没有了复数运算,$NTT$比$FFT$的常数也小了很多,一般是更好的选择
模板:
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=4e6+10,MOD=998244353; ll n,m,a[MAXN],b[MAXN],dgt,lmt=1,par[MAXN]; ll quick_pow(ll a,ll b) { ll ret=1; for(;b;b>>=1,a=a*a%MOD) if(b&1) ret=ret*a%MOD; return ret; } void FFT(ll *a,int flag) { for(int i=0;i<lmt;i++) if(i<par[i]) swap(a[i],a[par[i]]); for(int len=1;len<lmt;len<<=1) { ll unit=quick_pow(3,(MOD-1)/(len<<1)); if(flag==-1) unit=quick_pow(unit,MOD-2); for(int st=0;st<lmt;st+=(len<<1)) { ll w=1; for(int k=st;k<st+len;k++,w=w*unit%MOD) { ll A=a[k],B=w*a[k+len]%MOD; a[k]=(A+B)%MOD;a[k+len]=(A-B+MOD)%MOD; } } } } int main() { scanf("%lld%lld",&n,&m); for(int i=0;i<=n;i++) scanf("%lld",&a[i]); for(int i=0;i<=m;i++) scanf("%lld",&b[i]); while(lmt<=n+m) lmt<<=1,dgt++; for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1)); FFT(a,1);FFT(b,1); for(int i=0;i<lmt;i++) (a[i]*=b[i])%=MOD; FFT(a,-1); ll inv=quick_pow(lmt,MOD-2); for(int i=0;i<=n+m;i++) printf("%lld ",a[i]*inv%MOD); return 0; }
3、MTT
如果答案需要取模且模数非质数该如何处理呢?
常见背景为:多项式长度$1e5$,模数$1e9$非质数,此时$FFT$爆$longlong$,没法用$NTT$
(1)三模数$NTT$
根据上方的数据限制,可发现最终答案最多为$1e23$
这样就能用多个乘积大于$1e23$的模数分别做$NTT$最后再用$CRT$合并答案即可
一般常用:469762049,998244353,1004535809
可如果直接用$CRT$合并会发现模数爆$longlong$还是不好处理
此时就可以先合并前两个式子,得到
$res=k(mod(p_1*p_2)),res=a_3(mod(p_3))$
这样设$res=p_1*p_2*c+k$再带入二式就能得到$c=(a_3-k)*(p_1*p_2)^{-1}(mod(p_3))$
这样类似$exCRT$的分步处理就避开了对$p_1*p_2*p_3$的取模
但这样要进行9次DFT/IDFT,常数巨大无比
模板:
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=4e5+10; ll p[]={469762049,998244353,1004535809}; int n,m,MOD,F[MAXN],G[MAXN],dgt,lmt=1; ll a[3][MAXN],b[MAXN],res[MAXN],par[MAXN]; ll quickpow(ll a,ll b,ll MOD) { a%=MOD;ll ret=1; for(;b;b>>=1,a=a*a%MOD) if(b&1) ret=ret*a%MOD; return ret; } ll mul(ll a,ll b,ll MOD) { a=(a%MOD+MOD)%MOD; b=(b%MOD+MOD)%MOD;ll ret=0; for(;b;b>>=1,a=(a+a)%MOD) if(b&1) (ret+=a)%=MOD; return ret; } ll inv(ll a,ll MOD) {return quickpow(a,MOD-2,MOD);} void FFT(ll *a,int flag,ll MOD) { for(int i=0;i<lmt;i++) if(i<par[i]) swap(a[i],a[par[i]]); for(int len=1;len<lmt;len<<=1) { ll unit=quickpow(3,(MOD-1)/(len<<1),MOD); if(flag==-1) unit=inv(unit,MOD); for(int st=0;st<lmt;st+=(len<<1)) { ll w=1; for(int k=st;k<st+len;k++,w=w*unit%MOD) { ll A=a[k],B=w*a[k+len]%MOD; a[k]=(A+B)%MOD;a[k+len]=(A-B+MOD)%MOD; } } } if(flag==-1) { ll INV=inv(lmt,MOD); for(int i=0;i<lmt;i++) a[i]=a[i]*INV%MOD; } } void solve(ll *a,ll *b,ll MOD) { for(int i=0;i<=n;i++) a[i]=F[i]; for(int i=0;i<=m;i++) b[i]=G[i]; for(int i=m+1;i<lmt;i++) b[i]=0; FFT(a,1,MOD);FFT(b,1,MOD); for(int i=0;i<lmt;i++) a[i]=a[i]*b[i]%MOD; FFT(a,-1,MOD); } int main() { scanf("%d%d%d",&n,&m,&MOD); for(int i=0;i<=n;i++) scanf("%d",&F[i]); for(int i=0;i<=m;i++) scanf("%d",&G[i]); while(lmt<=n+m) lmt<<=1,dgt++; for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1)); for(int i=0;i<3;i++) solve(a[i],b,p[i]); for(int i=0;i<=n+m;i++) { ll M=p[0]*p[1]; ll A=(mul(a[0][i]*p[1],inv(p[1],p[0]),M)+ mul(a[1][i]*p[0],inv(p[0],p[1]),M))%M; ll K=mul(a[2][i]-A,inv(M,p[2]),p[2]); res[i]=(mul(K,M,MOD)+A%MOD)%MOD; } for(int i=0;i<=n+m;i++) printf("%lld ",res[i]); return 0; }