FFT快速傅里叶变换
说明
参考&鸣谢
小学生都能看懂的FFT!!!
十分简明易懂的FFT(快速傅里叶变换)
浅谈 FFT (终于懂一点了~~)
数学黑科技1——FFT
什么是FFT?
- FFT(Fast Fourier Transformation) 是离散傅氏变换(DFT)的快速算法。即为快速傅氏变换。它是根据离散傅氏变换的奇、偶、虚、实等特性,对离散傅立叶变换的算法进行改进获得的。——百度百科
- 这是什么并不重要。。。
用来做什么?
- 标准的FFT模板题,
- 设两个多项式:
-
A=i=0∑nai∗xi=an∗xn+an−1∗xn−1+...+a1∗x+a0,
-
B=i=0∑mbi∗xi=bn∗xm+bm−1∗xm−1+...+b1∗x+b0,
- 求
A∗B的值?
- 难道不就是
Ans=i=0∑n+msi∗xi=i=0∑n+mj=0∑iaj∗bi−j∗xi
- 是的,用
A的每一项和
B的每一项分别相乘,时间复杂度
O(nm).
- 但如何跑得更快呢?
- ——用FFT!
- 经典的应用,
- 高精度乘法(是不是十分显然)。
关于多项式
- 上面的多项式表示法是最普通的,也就是平时课本上看到的,称为“系数表示法“,因为显而易见每一项的系数分别是多少。
- 现在介绍另一种表示法——“点值表示法”,
- 一个
n次
n+1项式,可以将其看做是
n次函数的解析式,
- 为了唯一确定这个
n次多项式,需要选取
n+1个不同的
x代入该式子,分别得到一个对应的
y,
- 也就是,
-
an∗x0n+an−1∗x0n−1+...+a1∗x0+a0=y0,
-
an∗x1n+an−1∗x1n−1+...+a1∗x1+a0=y1,
-
……
-
an∗xnn+an−1∗xnn−1+...+a1∗xn+a0=yn,
- 那么就有两个多项式分别表示为(注意要取相同的且个数相同的
x),
-
A=((x0,yA0)),(x1,yA1),...,(xn,yAn))
-
B=((x0,yB0)),(x1,yB1),...,(xn,yBn))
- 则它们的积为,
-
Ans=((x0,yA0∗yB0),(x1,yA1∗yB1),...,(xn,yAn∗yBn))
- 把系数表示法变为点值表示法的运算叫做点值运算,
- 把点值表示法变为系数表示法的运算叫做插值运算,
- 这两种运算是互逆的,同时也一定是可逆的。
- 仿佛发现什么不可告人的秘密(难道就这么简单?)
- 按这样的步骤相乘不就是
O(n)的了?
- 没错,相乘是
O(n)的,但做点值和插值运算呢?还是
O(n2)的~~~
- 这两个运算的朴素过程分别被称为DFT(离散傅里叶变换)和IDFT(离散傅里叶逆变换)。
- 如何加速?
突破
- 傅里叶想到了要用单位根,
- 若
xn=1,则
x为
n次单位根,
- 那我们先想想,
- 实数范围内,
12=−12=1,
- 推广到虚数,
14=−14=i4=−i4=1,那也才
4个
x,远远不够。
- 那我们不如把实数和虚数结合吧,也就是复数!(高中会学)
为什么
关于复数
-
它的形如
a+bi的一个数,且
a,b∈R,
a为实部,
b为虚部,
-
关于它们的运算:
-
(a+bi)+(c+di)=(a+c)+(b+d)i
-
(a+bi)−(c+di)=(a−c)+(b−d)i
-
(a+bi)∗(c+di)=a∗c+a∗di+bi∗c+bi∗di=ac+(ad+bc)i+bd(i2)=(ac−bd)+(ad+bc)i
-
c+dia+bi=(c+di)(c−di)(a+bi)(c+di)=c2+d2(ac−bd)+(ad+bc)i
-
(除法可以不用管)
-
复数可以放在一个坐标系(复平面直角坐标系)中表示,
-
对于复数
a+bi,在坐标系中的坐标为
(a,b),
-
-
点到原点的线段称为模,模与
R轴(实数轴)正半轴的夹角称为幅角,
-
将复数乘法放在坐标系中,那么答案的位置就是相乘的两数模长相乘,幅角相加,
-
那么以原点为圆心,
1个单位长度为半径作圆,得到圆称为单位圆,
-
将圆从
(1,0)开始
n等分,第二个点为
wn1,就是
n次单位根,坐标为
(cos(n2π),sin(n2π))
-
则根据乘法在坐标系上的法则,可以得到每个点的表示的复数分别为
wn0,wn1..wnn−1,(
wn0=wnn)
-
wni坐标为
(cos(n2π∗i),sin(n2π∗i)),
π是按弧度制的等于
180°
-
(如
n=3)
-
它们都可以满足
(wni)n=1,
-
首先,模长无论乘多少次都是
1,
-
其次,
wni的幅角的
n倍为
n360°∗i∗n=360°∗i,也就是绕着原点转了
i圈,回到
x轴正半轴,
-
因此得证。
-
接下来要看单位根的一些性质,
-
一、
wdndi=wni
-
证明:
-
wdndi=cos(dn2π∗di)+sin(dn2π∗di)∗i=cos(n2π∗i)+sin(n2π∗i)∗i=wni
-
二、
wnk=−wnk+2n
-
相当于将这个点旋转
180°,则到了与它关于原点对称的位置,
-
证明:
-
wn2n=cos(n2π∗2n)+sin(n2π∗2n)∗i=cos(π)+sin(π)∗i=−1
-
wnk=−wnk∗(−1)=−wnk∗wn2n=−wnk+2n
FFT
- 终于可以进入正题了(铺垫实在是太多了)!!!
- 既然我们得出了单位根,那它一定能发挥一定作用,
- 根据上面的一些性质,DFT可以用分治来完成,时间复杂度优化到
O(nlog2n).
- 设
n项式
A(x),
-
A(x)=i=0∑n−1ai∗xi=an∗xn+an−1∗xn−1+...+a1∗x+a0
- 把偶数次项和奇数次项分开,
-
A(x)=(an−2∗xn−2+an−4∗xn−4+...+a2∗x2+a0)+(an−1∗xn−1+an−3∗xn−3+...+a3∗x3+a1∗x)
-
=(an−2∗xn−2+an−4∗xn−4+...+a2∗x2+a0)+x∗(an−1∗xn−2+an−3∗xn−4+...+a3∗x2+a1)
- 发现两边括号内每一项的指数已经相等了,那么再设
-
A0(x)=an−2∗xn−1+an−4∗xn−2+...+a2∗x+a0
-
A1(x)=an−1∗xn−1+an−3∗xn−2+...+a3∗x+a1
- (注意指数与
a的下标不是相等的了,而是变成了
21)
- 于是可以得到,
-
A(x)=A0(x2)+x∗A1(x2)
- 对于代入的
wnk(k<2n)分两种情况,有
-
A(wnk)
-
=A0((wnk)2)+wnk∗A1((wnk)2)
-
=A0(wn2k)+wnk∗A1(wn2k)
-
=A0(w2nk)+wnk∗A1(w2nk)
-
A(wnk+2n)
-
=A0((wnk+2n)2)+wnk+2n∗A1((wnk+2n)2)
-
=A0(wn2k+n)+wnk+2n∗A1(wn2∗k+n)
-
=A0(wnn∗wn2k)+wnk+2n∗A1(wnn∗wn2k)
-
=A0(w2nk)−wnk∗A1(w2nk)
- 因此,我们知道了
A0(x)和
A1(x)分别在
(w2n0,w2n1,...,w2n2n−1)的点值表示,就可以算出
A(x)在
(wn0,w01,...,wnn−1)的点值表示了。
- 显然这就是一个分治的过程,边界为
n=1
- 记得在做之前把
n补到
2的幂数,多出部分的系数就全都为
0,.
- 点值运算完成后,再根据上面的【关于多项式】和【为什么】的内容,就可以得出最终的答案了,
- 务必记得最终的多项式系数要全部除以
n.
- 这样一来,时间就可以优化到
O(nlogn).
板子(洛谷3809过不去的)
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
#define ld long double
#define N 2000010
const ld pi=acos(-1.0);
ld read()
{
ld s=0;
int x=getchar();
while(x<'0'||x>'9') x=getchar();
while(x>='0'&&x<='9') s=s*10+x-'0',x=getchar();
return s;
}
struct node
{
ld x,y;
node(ld xx=0,ld yy=0)
{
x=xx,y=yy;
}
}a[N],b[N];
node operator +(node x,node y)
{
return node{x.x+y.x,x.y+y.y};
}
node operator -(node x,node y)
{
return node{x.x-y.x,x.y-y.y};
}
node operator *(node x,node y)
{
return node(x.x*y.x-x.y*y.y,x.x*y.y+x.y*y.x);
}
void FFT(int len,node *a,int p)
{
if(len==1) return;
node a0[len/2+3],a1[len/2+3];
for(int i=0;i<=len;i+=2) a0[i/2]=a[i],a1[i/2]=a[i+1];
FFT(len/2,a0,p);
FFT(len/2,a1,p);
node w=node(cos(2*pi/len),p*sin(2*pi/len));
node w0=node(1,0);
for(int i=0;i<len/2;i++)
{
a[i]=a0[i]+w0*a1[i];
a[i+len/2]=a0[i]-w0*a1[i];
w0=w0*w;
}
}
int main()
{
int n,m,i;
n=read(),m=read();
for(i=0;i<=n;i++) a[i].x=read();
for(i=0;i<=m;i++) b[i].x=read();
int len=1;
for(;len<=n+m;len*=2);
FFT(len,a,1);
FFT(len,b,1);
for(i=0;i<=len;i++) a[i]=a[i]*b[i];
FFT(len,a,-1);
for(i=0;i<=n+m;i++) printf("%.0Lf ",a[i].x/len+1e-6);
return 0;
}
原因
解决办法
优化
- 原来的做法是先分别读入
a,b的实部(因为它们没有虚部),然后分别做点值运算,相乘后做插值运算。
- 现在把
b的实部直接作为
a的虚部,设改变之后的系数分别为
ci,
- 那么
ci=ai+bi∗i,
- 有
c2=a2−b2+2abi,那么虚部的
2ab在原先必须除以
n的基础上再多除以
2就是答案。
- 所以时间可以优化。
真·优化
- 想办法把
DFS去掉。
- 设原序列
(A0,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15),
- 分治每次变换后:
-
(A0,A2,A4,A6,A8,A10,A12,A14),(A1,A3,A5,A7,A9,A11,A13,A15)
-
(A0,A4,A8,A12),(A2,A6,A10,A14),(A1,A5,A9,A13),(A3,A7,A11,A15)
-
(A0,A8),(A4,A12),(A2,A10),(A6,A14),(A1,A9),(A5,A13),(A3,A11),(A7,A15)
- 到每组连个数时,把它们的下标和所在位置排列起来:
-
(0,8,4,12,2,10,6,14,1,9,5,13,3,11,7,15)
-
(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)
- 它们对应的二进制是:
-
(0000,1000,0100,1100,0010,1010,0110,1110,0001,1001,0101,1101,0011,1011,0111,1111)
-
(0000,0001,0010,0011,0100,0101,0110,0111,1000,1001,1010,1011,1100,1101,1110,1111)
- 发现对应的每一位都是相反的,
- 那就根据这个把最终的序列求出来,再一步步反推回去。
- 以上被称为蝴蝶变换。
真·板子
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
#define ld double
#define N 8000010
#define pi M_PI
ld read()
{
ld s=0;
int x=getchar();
while(x<'0'||x>'9') x=getchar();
while(x>='0'&&x<='9') s=s*10+x-'0',x=getchar();
return s;
}
struct node
{
ld x,y;
node(ld xx=0,ld yy=0)
{
x=xx,y=yy;
}
}a[N];
int rev[N];
node operator +(node x,node y)
{
return node{x.x+y.x,x.y+y.y};
}
node operator -(node x,node y)
{
return node{x.x-y.x,x.y-y.y};
}
node operator *(node x,node y)
{
return node(x.x*y.x-x.y*y.y,x.x*y.y+x.y*y.x);
}
void FFT(int len,ld p)
{
for(int i=0;i<len;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<len;i*=2)
{
node w=node(cos(pi/i),p*sin(pi/i));
for(int j=0;j<len;j+=2*i)
{
node w0=node(1,0);
for(int k=0;k<i;k++)
{
node A=a[k+j],B=a[k+j+i];
a[k+j]=A+B*w0;
a[k+j+i]=A-B*w0;
w0=w0*w;
}
}
}
}
int main()
{
int n,m,i;
n=read(),m=read();
for(i=0;i<=n;i++) a[i].x=read();
for(i=0;i<=m;i++) a[i].y=read();
int len=1,ls=0;
for(;len<=n+m;len*=2,ls++);
for(i=0;i<len;i++)
{
int t=i;
for(j=1;j<=ls;j++) rev[i]=rev[i]*2+t%2,t/=2;
}
FFT(len,1);
for(i=0;i<=len;i++) a[i]=a[i]*a[i];
FFT(len,-1);
for(i=0;i<=n+m;i++) printf("%.0lf ",a[i].y/2/len+1e-3);
return 0;
}