多项式乘法之FFT快速傅里叶变换

写在前面的话:

快速傅里叶应用于多项式,来加速多项式乘法使其复杂度从o(n^{2})变成o(n(\log_{2}n)),好了先一步步来,这个算法确实很难,但也不是完全理解不了,只是方式不对。

首先,正常的多项式相乘是:

A[n]=a_{0}+_{}a_{1}x+a_{2}x^{2}+a_{3}x^{3}+~~+a_{n-1}x^{n-1},B[n]=b_{0}+b_{1}x^{}+b_{2}x^{2}+b_{3}+~~b_{n-1}x^{n-1};

C=A[n]*B[n]=a_{0}\times b_{0}+a_{0}\times b_{1}x^{}+~~+a_{n-1}\times b_{n-1}x^{2n-2}

正常肯定因为这个应该不能再被简化了,但是没有绝对的最简,人们觉得这个复杂度太高,没有艺术美,于是FFT诞生了,多项式再是局限于o(n^{2})。

先讲清楚它的思想支柱-点值法。多个点可以确定一个函数,而一个函数也可以分解为多个点。

\begin{Bmatrix} x_{00} ,x_{01},x_{02} ,..., x_{0n-1} \\ x_{10} ,x_{11},x_{13},....,x_{1n-1} \\ x_{20} ,x_{21},x_{23},....,x_{2n-1} \\ ..... ,.....,......,......,....\\ x_{n-10} ,x_{n-11},x_{n-13},....,x_{n-1n-1}\end{Bmatrix}

比如你解一个二元一次方程需要{a+bx=y, 2a+bx=y}如果想解出a,b的值是不是至少需要两组x,y的值。代入,然后解方程。换言之,表示一个n次多项式是不是可以用n组解来表示。记不记得高斯消元法,就是这个,n组解确定,代入方程,n组方程,利用高斯消元法,就可以得到一组解,就可以确定原方程。

下面问题来了,点值法可以用,但是如何找到n组合适的解呢?不然就算点值法做,复杂度也是不会变的。

好了主题来了,轮到虚数出场了,问题的关键就是找到合适的解。

扫描二维码关注公众号,回复: 4123470 查看本文章

先来看一下虚数的定义:a+bi,a是实根,b是虚根。如果a+bi的长度为一的话,a,b可以怎么表示呢,a=cos\alpha,b=sin\alpha。原式等于cos\alpha+sin\alphai。这里扯一下著名的欧拉公式e^{i\alpha }=\cos\alpha +i\sin \alpha,当\alpha=\pi时,e^{i\pi }+1=0。美!公式本身就是一个艺术。证明用泰勒展开。

好了回归正题,e^{\frac{2\pi}{n}\times i }= \cos(\frac{2 \pi }{n})+i\times \sin (\frac{2 \pi }{n})e^{ \frac{2\times\pi }{n}\times i}\times e^{ \frac{2\times\pi }{n}\times i} = e^{\frac{2\times \pi }{\frac{n}{2}}},也就是角度乘2。

关键点来了:

w_{00}为1,w_{01}表示e^{\frac{2\times \pi }{n}\times i},w_{01}就是e^{ \frac{2\times\pi }{n}\times i}\times e^{ \frac{2\times\pi }{n}\times i} = e^{\frac{2\times \pi }{\frac{n}{2}}}对应上面的解矩阵,开始是1+x+x^2+x^3+.....+x^(n-1)表示一组解x=e^{\frac{2\times \pi }{n}\times i}。下一组解w_{10}=1,好了关键中的关键来了,第二组解就是在第一组解的基础上乘本身e^{ \frac{2\times\pi }{n}\times i}\times e^{ \frac{2\times\pi }{n}\times i} = e^{\frac{2\times \pi }{\frac{n}{2}}},然后继续乘x=e^{\frac{2\times \pi }{\frac{n}{2}}\times i},第三组同理。好了,开始神奇的变换了。

A(x)=a0+a1∗x+a2∗x^2+a3∗x^3+a4∗x^4+a5∗x^5+⋯+an−2∗x^n−2+an−1∗x^n−1。

A(x)=(a0+a2∗x^2+a4∗x^4+⋯+an−2∗x^n−2)+(a1∗x+a3∗x^3+a5∗x^5+⋯+an−1∗x^n−1)

A1(x)=a0+a2∗x+a4∗x^2+⋯+an−2∗x^n/2−1

A2(x)=a1+a3∗x+a5∗x^2+⋯+an−1∗x^n/2−1

A(x)=A1(x^2)+x*A2(x^2)

而上面已经证明e^{ \frac{2\times\pi }{n}\times i}\times e^{ \frac{2\times\pi }{n}\times i} = e^{\frac{2\times \pi }{\frac{n}{2}}},A(w_{n}^{k})=A1(w_{n}^{2k})+w_{n}^{k}A2(w_{n}^{2k})=A1(w_{\frac{n}{2}}^{k})+w_{n}^{k}A2(w_{\frac{n}{2}}^{k})

A(w_{n}^{k+\frac{n}{2}})=A1(w_{n}^{2k+n}})+w_{n}^{k+\frac{n}{2}}A(w_{n}^{2k+n}})=A1(w_{n}^{2k}\times w_{n}^{n})−w_{n}^{k}A2(w_{n}^{2k}\times w_{n}^{n})=A1(w_{\frac{n}{2}}^{k})-w_{n}^{k}A2(w_{\frac{n}{2}}^{k})

现在我们就可以得出A(w_{n}^{k}) (k\leqslant\frac{n}{2})是前一半,而后一半直接就可以得出A(w_{n}^{k+\frac{n}{2}}),复杂度减一半,递归下去。

下面是递归代码实现:

double PI = acos(-1);
class complex {
private:
	int r, i;
public:
	complex(double _r=0,double _i=0):r(_r),i(_i){}
	complex operator+(const complex &u) { return complex(r+u.r,i+u.i); }
	complex operator-(const complex &u) { return complex(r - u.r, i - u.i); }
	complex operator*(const complex &u) { return complex(r*u.r - i * u.i, r*u.i + i * u.r); }
};
void FFT(complex *a,int n,int type) {
	if (n == 1)return;
	complex *a1 = new complex[n >> 1], *a2 = new complex[n >> 1];
	for (int i = 0; i <n; i+=2) {
		a1[i>>1] = a[i];
		a2[i >>1] = a[i + 1];
	}
	FFT(a1, n >> 1, type);//递归偶数部分
	FFT(a2, n >> 1, type);//递归奇数部分
	complex wn(cos(2 * PI / n), type*sin(2 * PI / n)), w(1, 0);
	/*这段代码是核心,a[0]~a[n]相当于y[0]~y[n]
	a[i + n >> 1] = a1[i] - w * a2[i];相当于计算下一半
	*/
	for (int i = 0; i < n >> 1; i++,w=w*wn) {
		a[i] = a1[i] + w * a2[i];//前一半
		a[i + n >> 1] = a1[i] - w * a2[i];//后一半
	}
	delete &a1; delete &a2;
}

现在已经将系数转换为点值了,点值相乘,然后再转换为系数。

XA=Y,X^{-1}*X*A=X^{-1}*Y

A=X^{-1}*Y,A就是解方程的系数矩阵,

那么现在只需要求X的逆矩阵就可以了,回到上面的欧拉公式,e^{\frac{2\pi}{n}\times i }= \cos(\frac{2 \pi }{n})+i\times \sin (\frac{2 \pi }{n})e^{\frac{2\pi}{n}\times i }\times e^{-\frac{2\pi }{n}\times i}= e^{0}=0

所以只需要将他的符号改一下就变成了逆矩阵

这又是一个FFT只不过符号要变为负的,最后值要除n。

c++complex模板实现:

#include <iostream>
using namespace std;
#include <complex>
#include <cmath>
double PI = acos(-1);
complex<double> a[400010], b[400010], c[400010];
void fft(complex<double> *a, int n, int op)
{
	if (n == 1) return;
	complex<double> w(1, 0), wn(cos(2 * PI*op / n), sin(2 * PI*op / n));
	complex<double>*a1 = new complex<double>[n >> 1], *a2 = new complex<double>[n >> 1];
	for (int i = 0; i < (n >> 1); i++)
		a1[i] = a[i << 1], a2[i] = a[(i << 1) + 1];
	fft(a1, n >> 1, op), fft(a2, n >> 1, op);
	for (int i = 0; i < (n >> 1); i++, w *= wn)
		a[i] = a1[i] + w * a2[i], a[i + (n >> 1)] = a1[i] - w * a2[i];
}
int main()
{
	int n, m;
	scanf("%d%d", &n, &m);
	for (int i = 0; i <= n; i++) scanf("%lf", &a[i]);
	for (int i = 0; i <= m; i++) scanf("%lf", &b[i]);
	m += n, n = 1;
	while (n <= m) n <<= 1;
	fft(a, n, 1), fft(b, n, 1);
	for (int i = 0; i < n; i++) c[i] = a[i] * b[i];
	fft(c, n, -1);
	for (int i = 0; i <= m; i++) printf("%d ", int(c[i].real() / n + 0.5));
	 //system("pause");
                 return 0;
}

实际算法中使用的是迭代法:

奇偶数可以用二进制反转来分,这样就不需要递归了,网上找的迭代法

#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <complex>
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
    char c;
    bool op = 0;
    while(c = getchar(), c < '0' || c > '9')
    if(c == '-') op = 1;
        x = c - '0';
    while(c = getchar(), c >= '0' && c <= '9')
        x = x * 10 + c - '0';
    if(op) x = -x;
}
template <class T>
void write(T x){
    if(x < 0) putchar('-'), x = -x;
    if(x >= 10) write(x / 10);
    putchar('0' + x % 10);
}
const int N = 1000005;
const double PI = acos(-1);
typedef complex <double> cp;
char sa[N], sb[N];
int n = 1, lena, lenb, res[N];
cp a[N], b[N], omg[N], inv[N];
void init(){
    for(int i = 0; i < n; i++){
        omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
        inv[i] = conj(omg[i]);
    }
}
void fft(cp *a, cp *omg){
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++){
        int t = 0;
        for(int j = 0; j < lim; j++)
            if((i >> j) & 1) t |= (1 << (lim - j - 1));//每次移位,比较第一位看是否为1,如果是1,那么对应的另一端要变为1
        if(i < t) swap(a[i], a[t]); // i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
    }
    for(int l = 2; l <= n; l *= 2){
        int m = l / 2;
    for(cp *p = a; p != a + n; p += l)
        for(int i = 0; i < m; i++){
            cp t = omg[n / l * i] * p[i + m];//这步看不懂的可以推一下总和等于奇数加偶数的公式,每次提取出的数是从小到大的
            p[i + m] = p[i] - t;
            p[i] += t;
        }
    }
}
int main(){
    scanf("%s%s", sa, sb);
    lena = strlen(sa), lenb = strlen(sb);
    while(n < lena + lenb) n *= 2;
    for(int i = 0; i < lena; i++)
        a[i].real(sa[lena - 1 - i] - '0');
    for(int i = 0; i < lenb; i++)
        b[i].real(sb[lenb - 1 - i] - '0');
    init();
    fft(a, omg);
    fft(b, omg);
    for(int i = 0; i < n; i++)
        a[i] *= b[i];
    fft(a, inv);
    for(int i = 0; i < n; i++){
        res[i] += floor(a[i].real() / n + 0.5);
        res[i + 1] += res[i] / 10;
        res[i] %= 10;
    }
    for(int i = res[lena + lenb - 1] ? lena + lenb - 1: lena + lenb - 2; i >= 0; i--)
        putchar('0' + res[i]);
    enter;
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_36973725/article/details/84074599
今日推荐