浅谈FFT&NTT

复数及单位根

复数的定义大概就是:\(i^2=-1\),其中\(i\)就是虚数单位。

那么,在复数意义下,对于方程:
\[ x^n=1 \]
就必定有\(n\)个解,这\(n\)个解的分布一定是在复平面上,以圆点为圆心,半径为\(1\)的圆的\(n\)等分点。

由于欧拉公式:
\[ e^{i\theta}=\cos\theta+i\cdot \sin\theta \]
\(2\pi\)带入:
\[ e^{2i\pi}=1 \]
比较一下这个和上面的方程,设:
\[ \omega_n=e^{2i\pi/n} \]
那么可以得到上面方程的\(n\)个解分别为:
\[ \forall i\in[0,n-1],x_i=\omega_n^i \]
那么,我们称这\(n\)个解为\(n\)次单位根。

关于单位根,有以下性质:

\[ \omega_n^x=-\omega_n^{x+\frac{n}{2}},w_n^2=w_{\frac{n}{2}} \]
这些性质的证明都很简单。

点值表达式

考虑到,一个多项式可以看做是一个\(n​\)次的函数,如果已知这个函数的\(n+1​\)个点,那么就可以确定这个多项式。

任取\(n+1\)个不同的数\(x_i\),知道了多项式的结果\(F(x_i)\),这个称作多项式的点值表达式

离散傅里叶变换(Discrete Fourier Transform, DFT)

对于一个\(n-1​\)次多项式,取\(n​\)个数\(w_n^0,w_n^1...w_n^{n-1}​\),得到一个点值表达式,称作离散傅里叶变换

先把这个多项式凑成\(n=2^x\)的形式,高位补\(0\)

对于\(F(\omega_n^{k})​\),显然可以得到:
\[ F(\omega_n^k)=\sum_{i=0}^{n-1}(\omega_n^k)^i\cdot A_i \]
其中\(A_i\)为系数。

然后对这个进行奇偶分类,可得:
\[ \begin{align} F(\omega_n^k)&=\sum_{i=0}^{n/2-1}(\omega_n^{k})^{2i}\cdot A_{2i}+\sum_{i=0}^{n/2-1}(\omega_n^k)^{2i+1}\cdot A_{2i+1}\\ &=\sum_{i=0}^{n/2-1}(\omega_{n/2}^{k})^i\cdot A_{2i}+\omega_n^k\cdot \sum_{i=0}^{n/2-1}(\omega_{n/2}^k)^{i}\cdot A_{2i+1} \end{align} \]

\(F_0(x)\)为偶数项的系数构成的多项式,\(F_1(x)\)为奇数项,这个显然是一个子问题。

那么:
\[ F(\omega_n^k)=F_0(\omega_{n/2}^k)+w_n^k\cdot F_1(\omega_{n/2}^k) \]
所以,令\(k\leqslant n/2\),则有:
\[ F(\omega_n^{k+n/2})=F_0(\omega_{n/2}^k)+w_n^{k+n/2}\cdot F_1(\omega_{n/2}^k) \]
即:
\[ F(\omega_n^{k})=F_0(\omega_{n/2}^k)+w_n^{k}\cdot F_1(\omega_{n/2}^k) \\F(\omega_n^{k+n/2})=F_0(\omega_{n/2}^k)-w_n^{k}\cdot F_1(\omega_{n/2}^k) \]
递归计算即可,复杂度:
\[ T(n)=2 \cdot T(\frac{n}{2})+O(n)=O(n\log n) \]

离散傅里叶逆变换(Inverse Discrete Fourier Transform, IDFT)

对于离散傅里叶变换,写成矩阵的形式就是:
\[ \begin{bmatrix} (\omega_n^0)^0&(\omega_n^0)^1&\cdots & (\omega_n^0)^{n-1}\\ (\omega_n^1)^0&(\omega_n^1)^1&\cdots & (\omega_n^1)^{n-1}\\ \vdots&\vdots&\ddots&\vdots\\ (\omega_n^{n-1})^0&(\omega_n^{n-1})^1&\cdots & (\omega_n^{n-1})^{n-1}\\ \end{bmatrix} \times \begin{bmatrix} A_0\\A_1\\\vdots\\A_{n-1} \end{bmatrix} = \begin{bmatrix} F(\omega_n^0)\\F(\omega_n^1)\\\vdots\\F(\omega_n^{n-1}) \end{bmatrix} \]
现在,我们是知道了等号右边的\(F​\),要求等号左边的\(A​\)

设上面的系数矩阵为\(s\),考虑下面这个矩阵,设为\(t\)
\[ t=\begin{bmatrix} (\omega_n^{-0})^0&(\omega_n^{-0})^1&\cdots & (\omega_n^{-0})^{n-1}\\ (\omega_n^{-1})^0&(\omega_n^{-1})^1&\cdots & (\omega_n^{-1})^{n-1}\\ \vdots&\vdots&\ddots&\vdots\\ (\omega_n^{-(n-1)})^0&(\omega_n^{-(n-1)})^1&\cdots & (\omega_n^{-(n-1)})^{n-1}\\ \end{bmatrix} \]
考虑矩阵\(v=t\times s\)

对于\(v_{i,j}​\),根据矩阵乘法规则,它会等于:
\[ v_{i,j}=\sum_{k=0}^{n-1}(\omega_n^{-i})^{k}\cdot (\omega_{n}^{k})^{j}=\sum_{k=0}^{n-1}\omega_n^{k(j-i)} \]
\(i=j\),则:
\[ v_{i,j}=n \]
否则:
\[ v_{i,j}=\sum_{k=0}^{n-1}\omega_n^{k(j-i)}=\frac{1-(\omega_n^{j-i})^n}{1-\omega_n^{j-i}} \]
注意到:
\[ \omega_n^n=0 \]
所以:
\[ v_{i,j}=0 \]
然后把这个矩阵写出来:
\[ v=\begin{bmatrix} n&0&\cdots&0\\ 0&n&\cdots&0\\ \vdots&\vdots&\ddots&\vdots\\ 0&0&\cdots&n \end{bmatrix} \]
然后可以发现,这个就是单位矩阵的\(n\)倍,即:
\[ t\times s=n\cdot \epsilon \]
然后考虑第一个矩阵的式子,等式两边同时左乘一个\(t\),可得:
\[ n\cdot \begin{bmatrix} A_0\\A_1\\\vdots\\A_{n-1} \end{bmatrix} = \begin{bmatrix} (\omega_n^{-0})^0&(\omega_n^{-0})^1&\cdots & (\omega_n^{-0})^{n-1}\\ (\omega_n^{-1})^0&(\omega_n^{-1})^1&\cdots & (\omega_n^{-1})^{n-1}\\ \vdots&\vdots&\ddots&\vdots\\ (\omega_n^{-(n-1)})^0&(\omega_n^{-(n-1)})^1&\cdots & (\omega_n^{-(n-1)})^{n-1}\\ \end{bmatrix} \times \begin{bmatrix} F(\omega_n^0)\\F(\omega_n^1)\\\vdots\\F(\omega_n^{n-1}) \end{bmatrix} \]
所以,\(IDFT​\)的时候直接照搬\(DFT​\),然后把\(\omega_n^k​\)改成\(\omega_n^{-k}​\),最后在除个\(n​\)就好了。

迭代实现

由于上面的递归实现常数过大,不是很优秀,这里有一种迭代的实现方法。

考虑我们把递归过程改成迭代,那么显然我们需要把顺序重新排列一下,然后每次把相邻的\(2^k\)个数合并就好了。

\(n=2^m\),考虑第\(i\)次递归的时候,二进制下第\(i\)\(0\)的放左边,为\(1\)的放右边,那么可以发现,左边的所有数新位置的编号第\(m-i+1\)位都为\(0\),右边的为\(1\),这个可以自己画下图理解下。

那么,设\(rev(x)\)表示把\(x\)的二进制翻转的结果,即第\(i\)位和第\(m-i+1\)位交换。

对于原序列第\(i​\)个数,他在新序列的位置就应该是\(rev(i)​\)

代码就比较好写了:

#include<cmath>
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
 
void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

const int maxn = 4e6+10;

#define lf double

const lf pi = acos(-1);

struct complex {
    lf real,imag;
    complex () {}
    complex (lf _real,lf _imag) {real=_real,imag=_imag;}
    complex conj() {return complex(real,-imag);}  //共轭复数
    complex operator = (const int &rhs) {real=rhs;return *this;}
    complex operator + (const complex &rhs) const {return complex(real+rhs.real,imag+rhs.imag);}
    complex operator - (const complex &rhs) const {return complex(real-rhs.real,imag-rhs.imag);}
    complex operator * (const complex &rhs) const {return complex(real*rhs.real-imag*rhs.imag,imag*rhs.real+real*rhs.imag);}
};   //手写的一个复数类

complex es[maxn],ces[maxn],a[maxn],b[maxn];
int n,m,N,pos[maxn],bit;

void init() {
    for(int i=0;i<N;i++) es[i]=complex(cos(2*pi/N*i),sin(2*pi/N*i));
    for(int i=0;i<N;i++) ces[i]=es[i].conj();  //预处理单位根
    for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));  //pos[x]表示rev(x)
}

void fft(complex *r,complex *w) {
    for(int i=0;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);  //调整位置
    for(int i=1;i<N;i<<=1) 
        for(int j=0;j<N;j+=(i<<1))
            for(int k=0;k<i;k++) {
                complex x=r[j+k],y=w[N/(i<<1)*k]*r[j+k+i];  //迭代实现
                r[j+k]=x+y,r[i+j+k]=x-y;
            }
}

int main() {
    read(n),read(m);
    for(int i=0,x;i<=n;i++) read(x),a[i]=x;
    for(int i=0,x;i<=m;i++) read(x),b[i]=x;
    N=1;while(N<=n+m) N<<=1,bit++;
    init();fft(a,es),fft(b,es);
    for(int i=0;i<=N;i++) a[i]=a[i]*b[i];fft(a,ces);
    for(int i=0;i<=n+m;i++) printf("%d ",(int)(a[i].real/N+0.5));puts("");  //记得答案要除N,这个其实应该写在fft函数里面。。
    return 0;
}

这份代码在洛谷的模板P3803 【模板】多项式乘法(FFT)提交可以通过。

快速数论变换(Fast Number-Theoretic Transform,FNT)

这玩意其实一般叫做\(NTT\)

考虑到上面\(FFT\)的过程用到了单位根的哪些性质:

  1. \(\omega_n^0,\omega_n^1...\omega_n^{n-1}\)互不相同,这保证了点值表达式可以成立。
  2. \(\omega_n^2=\omega_{n/2}\)\(\omega_n^{k+n/2}=-\omega_n^k\)
  3. \(\omega_n^n=1\),这保证了IDFT的正确性。

对于模数\(p=k\cdot 2^s+1\),且\(p\)为质数,设它的原根为\(g\),那么我们可以令\(\omega_n=g^{(p-1)/n}\)

由于原根的性质,第一条显然是满足的。

对于第二条:
\[ \omega_n^2=g^{2(p-1)/n}=g^{(p-1)/(n/2)}=\omega_{n/2} \]
并且:
\[ \omega_n^{n/2}=g^{(p-1)/2}=-1 \]
也比较显然。

对于第三点,其实就是费马小定理,显然满足,所以我们可以用这个来替代\(\omega_n\),进行数论变换,代码也差不多。

注意,对于质数\(p=k\cdot 2^s+1\),它能处理的数据范围是\(n\leqslant 2^s\)

模板:题目和上题相同

#include<bits/stdc++.h>
using namespace std;
 
void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

const int maxn = 4e6+10;
const int mod = 998244353;

int n,m,N=1,bit,pos[maxn],es[maxn],ces[maxn],a[maxn],b[maxn];

int qpow(int aa,int x) {
    int res=1;
    for(;x;x>>=1,aa=1ll*aa*aa%mod) if(x&1) res=1ll*res*aa%mod;
    return res;
}

void ntt(int *r,int f) {
    for(int i=0;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
    for(int i=1;i<N;i<<=1) {
        int wn=qpow(f==1?3:qpow(3,mod-2),(mod-1)/(i<<1));
        for(int j=0,w=1;j<N;j+=(i<<1),w=1) 
            for(int k=0;k<i;k++,w=1ll*w*wn%mod) {
                int x=r[j+k],y=1ll*w*r[i+j+k]%mod;
                r[j+k]=(x+y)%mod,r[i+j+k]=(x-y)%mod;
            }
    }
}

int main() {
    read(n),read(m);
    for(int i=0;i<=n;i++) read(a[i]);
    for(int i=0;i<=m;i++) read(b[i]);
    while(N<=n+m) N<<=1,bit++;
    for(int i=0;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
    ntt(a,1),ntt(b,1);
    for(int i=0;i<=N;i++) a[i]=1ll*a[i]*b[i]%mod;
    ntt(a,-1);int inv=qpow(N,mod-2);
    for(int i=0;i<=n+m;i++) printf("%d ",((1ll*a[i]*inv%mod)+mod)%mod);puts("");
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/hbyer/p/10325916.html
今日推荐