学习自FFT详解。
很久前就想学,然而一直不能理解,这两天稍微懂了一些。
含义及用途
FFT(Fast Fourier Transformation)是离散傅氏变换(DFT)的快速算法。即为快速傅氏变换。它是根据离散傅氏变换的奇、偶、虚、实等特性,对离散傅立叶变换的算法进行改进获得的,把多项式乘法的复杂度从 降到了 (然而常数很大)。
过程
前置知识
多项式的两种表示方法
设多项式 。
系数表示法:
点值表示法:把多项式看成函数, 为函数上的点,则
单位复根
即 。其中 。
则有
DFT
把 的项按照次数奇偶分开,则有
令
则
这样可以一直递归下去。
然而会发现左右两半长度不相同,会没法搞,于是我们把位数补到 ,然后把 的单位复根代入即可。
然而递归很慢,我们可以尝试找找规律。
假设 的最高项次数为 。
我们来模拟一下:
把编号转化成二进制,前后比对:
然后通过看别人Blog观察可以发现(当然可以证明)最终编号就是原来编号在二进制下的翻转。
于是我们可以省去递归的操作,直接从底回溯即可(实际上用“蝴蝶操作”可以直接迭代,这里仅作代码注释)。
IDFT
把单位复根取倒数,做一遍DFT并把结果除以 就可以把点值表示转化为系数表示法。
因为博主太菜证不来这里不做证明,实现见代码。
模板
#include<cmath>
#include<cctype>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 1<<21
#define F inline
#define fr first
#define sc second
#define MP make_pair
using namespace std;
typedef long double DB;
typedef pair<DB,DB> P;
const DB pi=acos(-1);
int n,m,l,R[N],c[N];
P a[N],b[N];
F char readc(){
static char buf[100000],*l=buf,*r=buf;
if (l==r) r=(l=buf)+fread(buf,1,100000,stdin);
return l==r?EOF:*l++;
}
F int _read(){
int x=0; char ch=readc();
while (!isdigit(ch)) ch=readc();
while (isdigit(ch)) x=(x<<3)+(x<<1)+(ch^48),ch=readc();
return x;
}
F void writec(int x){ if (x>9) writec(x/10); putchar(x%10+48); }
F void _write(int x){ writec(x),putchar(' '); }
//以下三个运算为在复平面内的运算
F P operator + (P a,P b){ return MP(a.fr+b.fr,a.sc+b.sc); }
F P operator - (P a,P b){ return MP(a.fr-b.fr,a.sc-b.sc); }
F P operator * (P a,P b){ return MP(a.fr*b.fr-a.sc*b.sc,a.fr*b.sc+a.sc*b.fr); }
F void FFT(P *a,int n,int f){
for (int i=0;i<n;i++) if (i<R[i]) swap(a[i],a[R[i]]);
//把原来的转化为最终的多项式,这里大于小于都无所谓,只要不换回来就行
for (int k=1;k<n;k<<=1){//蝴蝶操作,k表示当前合并的大小
P w=MP(1,0),wn=MP(cos(pi/k),sin(f*pi/k)),x,y;
//wn为单位复根,w为sqrt(-1),x为f1,y为f2
for (int i=0;i<n;i+=(k<<1),w=MP(1,0))
for (int j=0;j<k;j++,w=w*wn)
x=a[i+j],y=w*a[i+j+k],a[i+j]=x+y,a[i+j+k]=x-y;
}
}
int main(){
n=_read(),m=_read();
for (int i=0;i<=n;i++) a[i]=MP(_read(),0);
for (int i=0;i<=m;i++) b[i]=MP(_read(),0);
for (m+=n,n=1;n<=m;n<<=1) l++;
for (int i=0;i<=n;i++)
R[i]=(R[i>>1]>>1)|((i&1)<<(l-1));//翻转
FFT(a,n,1),FFT(b,n,1);
for (int i=0;i<n;i++) a[i]=a[i]*b[i];
FFT(a,n,-1);
for (int i=0;i<=m;i++) _write(a[i].fr/n+0.5);
return 0;
}
经典应用:高精度乘法。
把每个数看作多项式每一项之前的系数即可。
代码:
#include<cmath>
#include<cctype>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 1<<18
#define F inline
using namespace std;
typedef long double DB;
const DB pi=acos(-1);
struct P{ DB x,y; }a[N],b[N];
int n,l,m,r[N],c[N],ans[N];
char bf[N];
F char readc(){
static char buf[100000],*l=buf,*r=buf;
if (l==r) r=(l=buf)+fread(buf,1,100000,stdin);
return l==r?EOF:*l++;
}
F int _read(){
int x=0; char ch=readc();
while (!isdigit(ch)) ch=readc();
return ch-48;
}
F P operator + (P a,P b){ return (P){a.x+b.x,a.y+b.y}; }
F P operator - (P a,P b){ return (P){a.x-b.x,a.y-b.y}; }
F P operator * (P a,P b){ return (P){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x}; }
F void FFT(P *a,int f){
for (int i=0;i<n;i++) if (i<r[i]) swap(a[i],a[r[i]]);
for (int k=1;k<n;k<<=1){
P w={1,0},wn={cos(pi/k),sin(f*pi/k)},x,y;
for (int i=0;i<n;i+=(k<<1),w=(P){1,0})
for (int j=0;j<k;j++,w=w*wn)
x=a[i+j],y=w*a[i+j+k],a[i+j]=x+y,a[i+j+k]=x-y;
}
}
F void _write(){
for (int i=0;i<=m;i++) bf[i]=(char)ans[m-i]+48;
puts(bf);
}
int main(){
scanf("%d",&n);
for (int i=1;i<=n;i++)
a[n-i]=(P){_read(),0};
for (int i=1;i<=n;i++) b[n-i]=(P){_read(),0};
for (m=n<<1|1,n=1;n<m-1;n<<=1) l++;
for (int i=0;i<=n;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
FFT(a,1),FFT(b,1);
for (int i=0;i<n;i++) a[i]=a[i]*b[i];
FFT(a,-1);
for (int i=0;i<=n;i++){
ans[i]+=a[i].x/n+0.5;
ans[i+1]+=ans[i]/10,ans[i]%=10;
}
while (!ans[m]&&m) m--;
return _write(),0;
}