快速傅里叶变换模板记录

例题记得应该是 hdu的 1402,傅里叶变换是一个模板。怎么写是大致固定的,但是如何使用就有考究了。具体傅里叶的讲解我存上一片博客,写的相当不错。FFT详解---在此声明地址(https://blog.csdn.net/GGN_2015/article/details/68922404)。非常佩服这样的同学,不仅知识学的多,而且还可以写的文章清晰明了,讲解清楚。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <string>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <time.h>
#include <complex>

using namespace std;
#define lson 2*i
#define rson 2*i+1
#define LS l,mid,lson
#define RS mid+1,r,rson
#define UP(i,x,y) for(i=x;i<=y;i++)
#define DOWN(i,x,y) for(i=x;i>=y;i--)
#define MEM(a,x) memset(a,x,sizeof(a))
#define gcd(a,b) __gcd(a,b)
#define LL long long
#define N 1000005
#define MOD 1000000007
#define INF 0x3f3f3f3f
#define EXP 1e-8
#define lowbit(x) (x&-x)
typedef complex<double> CD;		//复数类的定义
const int maxl = 2094153;		//nlogn的最大长度
const double PI = acos(-1.0);	//圆周率
CD a[maxl],b[maxl];				//用于存储变换的中间结果
int rev[maxl];					//二进制翻转结果
int ans[maxl];
int T;
char s1[maxl],s2[maxl];
int n,m;

void getRev(int bit) {			//二进制翻转函数
	for (int i = 0; i < (1<<bit); i++)	//高位决定二进制大小
		rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1));
}								//能保证(x>>1)<x,满足递推性质

void FFT(CD* a,int n,int DFT) {	//快速傅立叶变换的过程, DFT参数是1则是正变换,-1则是逆离散傅氏变换.
	for (int i = 0; i < n; i++) {//按照二进制反转
		if (i<rev[i])			//保证只把前面的数和后面的数交换,(否则数组会被翻回来)
			swap(a[i],a[rev[i]]);
	}
	for (int step = 1; step < n; step<<=1) {	//枚举步长的一半
		CD omg_n = exp(CD(0,DFT*PI/step));		//计算单位复根
		for (int j = 0; j < n; j+=step<<1) {	//对于每一块
			CD omg_nk(1,0);						//!!每一块都是独立的序列,都是以零次方位为起始的
			for (int k = j; k < j+step; k++) {	//蝴蝶操作,蝶形操作的这一块
				CD x = a[k];
				CD y = omg_nk*a[k+step];
				a[k] = x+y;
				a[k+step] = x-y;
				omg_nk*=omg_n;					//计算下一次的复根
			}
		}
	}
	if (DFT == -1) {							//IDFT逆离散傅氏变换,就要将序列除以n
		for (int i = 0; i < n; i++) 
			a[i]/=n;
	}
}
inline int read(){
    int f=1,x=0;char ch;
    do{ch=getchar();if(ch=='-')f=-1;}while(ch<'0' || ch>'9');
    do{x=x*10+ch-'0';ch=getchar();}while(ch>='0'&& ch<='9');
    return f*x;
}

int main(int argc, char const *argv[]) {
	scanf("%d%d", &n,&m);
	for (int i = 0; i <= n; i++)
		a[i] = read();
	for (int i = 0; i <= m; i++)
		b[i] = read();
	int bit = 1,s=2;
	for (bit=1;(1<<bit)<n+m-1;bit++)
		s<<=1;
	getRev(bit);
	FFT(a,n,1);FFT(b,n,1);
	for (int i = 0; i <= n; i++)	a[i] = a[i]*b[i];
	FFT(a,n,-1);
	for (int i = 0; i <= m; i++)	printf("%d ",(int)a[i].real()/n+0.5);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/SuperBvs/article/details/87900070