算法学习FFT系列(4):任意模数的快速傅里叶变换(MTT)

算法学习FFT系列(4):任意模数的快速傅里叶变换(MTT)

毛神好强,毛神好强

任意模数的快速傅里叶变换的解法

这里假设序列是 10 5 级别的,模数是 10 9 级别的。
首先找到问题的瓶颈。
由于模数的任意性,所以NTT失效了。
而由于如果没有模数,最后的结果是在 10 23 级别,会太大,所以FFT也不适用。
既然找到了瓶颈,所以我们解决的方针就有两种:(1)把NTT推广到任意模数的形式。(2)通过某种方法分步计算FFT使得精度符合要求。

三模数NTT

根据第一种思路,找到三个符合要求的模数在这三个模数意义下分别FFT然后利用中国剩余定理合并一下。
以后有空再补吧,今天主要介绍另外一种。

拆系数FFT

讲每个数拆成 k M + b 的形式,其中 M 是常数
考虑卷积的过程。
( k 1 M + b 1 ) ( k 2 M + b 2 ) = k 1 k 2 M 2 + b 1 k 2 M + b 2 k 1 M + b 1 b 2
M 这个东西可以FFT出来之后乘
这样子的话,大概估计一下范围,考虑最大的情况。
k 1 k 2 = P M P M = P 2 M 2
b 1 k 2 ( b 2 k 1 ) = M P M
b 1 b 2 = M 2
不难发现,当 M = P 的时候,这些东西都是 P 级别的。
所以FFT出来的结果是 10 14 级别的。
这样子的话,我们将整个序列拆成4个序列。做4次DFT,3次IDFT即可。
但是这样有两个坏处,一个是精度不行,还有一个就是7的常数,如果加上FFT本身的常数可能快两个log了。

DFT合并和IDFT合并

这是一个神奇的优化常数的技巧。
最早是Codeforces上的神仙提出的,毛神的论文里有
具体的方法是构造共轭式,我们定义
P ( x ) = A ( x ) + i B ( x ) , Q ( x ) = A ( x ) i B ( x )
为了方便,令 X = 2 π j k L 考虑DFT后 P ( ω k ) Q ( ω k ) 的关系
D F T ( p k ) = j = 0 L 1 ( a j + i b j ) ω j k
= j = 0 L 1 ( a j + i b j ) ( c o s X + i s i n X )
= j = 0 L 1 ( a j c o s X b j s i n X ) + i ( b j c o s X + a j s i n X )
Q ( ω k ) = j = 0 L 1 ( a j i b j ) ω j k
= j = 0 L 1 ( a j i b j ) ( c o s X + i s i n X )
= j = 0 L 1 ( a j c o s X + b j s i n X ) + i ( a j s i n X b j c o s X )
= j = 0 L 1 ( a j c o s ( X ) b j s i n ( X ) ) i ( b j c o s ( X ) + a j s i n ( X ) )
= c o n j ( P ( ω k ) ) = c o n j ( P ( ω L k ) )
秀了一波推到,我们发现只要DFT出P,我们就能得到Q的DFT
然后
A ( ω k ) = P ( ω k ) + Q ( ω k ) 2
B ( ω k ) = i Q ( ω k ) P ( ω k ) 2
这样子的话,两次可以优化到一次。
那我们考虑IDFT,其实只要逆回去就行了。
M ( ω k ) = A ( ω k ) + i B ( ω k ) = P ( ω k ) + Q ( ω k ) 2 Q ( ω k ) P ( ω k ) 2 = P ( ω k )
I D F T ( P ( ω k ) ) = I D F T ( M ( ω k ) )
由于IDFT前后都是实数,直接把实部和虚部掏出来即可。
这样子的话,我们可以把两个IDFT合并在一起搞。
这样子的话,MTT就变成了4次DFT了。

代码

//luoguP4245 【模板】MTT 
#include<cstdio>
#include<cmath>
#include<algorithm>
const int N = 262144 + 10, M = 32767;
const double pi = acos(-1.0);
typedef long long LL;
int read() {
    char ch = getchar(); int f = 1, x = 0;
    for(;ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') f = -1;
    for(;ch >= '0' && ch <= '9'; ch = getchar()) x = (x << 1) + (x << 3) - '0' + ch;
    return x * f;
}
struct cp {
    double r, i;
    cp(double _r = 0, double _i = 0) : r(_r), i(_i) {}
    cp operator * (const cp &a) {return cp(r * a.r - i * a.i, r * a.i + i * a.r);}
    cp operator + (const cp &a) {return cp(r + a.r, i + a.i);}
    cp operator - (const cp &a) {return cp(r - a.r, i - a.i);}
}w[N], nw[N], da[N], db[N];
cp conj(cp a) {return cp(a.r, -a.i);}
int L, n, m, a[N], b[N], c[N], R[N], P;
void Pre() {
    int x = 0; for(L = 1; (L <<= 1) <= n + m; ++x) ;
    for(int i = 1;i < L; ++i) R[i] = (R[i >> 1] >> 1) | (i & 1) << x;
    for(int i = 0;i < L; ++i) w[i] = cp(cos(2 * pi * i / L), sin(2 * pi * i / L));
}
void FFT(cp *F) {
    for(int i = 0;i < L; ++i) if(i < R[i]) std::swap(F[i], F[R[i]]);
    for(int i = 2, d = L >> 1;i <= L; i <<= 1, d >>= 1) 
        for(int j = 0;j < L; j += i) {
            cp *l = F + j, *r = F + j + (i >> 1), *p = w, tp;
            for(int k = 0;k < (i >> 1); ++k, ++l, ++r, p += d) 
                tp = *r * *p, *r = *l - tp, *l = *l + tp;
        }
}
void Mul(int *A, int *B, int *C) {
    for(int i = 0;i < L; ++i) (A[i] += P) %= P, (B[i] += P) %= P;
    static cp a[N], b[N], Da[N], Db[N], Dc[N], Dd[N];
    for(int i = 0;i < L; ++i) a[i] = cp(A[i] & M, A[i] >> 15);
    for(int i = 0;i < L; ++i) b[i] = cp(B[i] & M, B[i] >> 15);
    FFT(a); FFT(b);
    for(int i = 0;i < L; ++i) {
        int j = (L - i) & (L - 1); static cp da, db, dc, dd;
        da = (a[i] + conj(a[j])) * cp(0.5, 0);
        db = (a[i] - conj(a[j])) * cp(0, -0.5);
        dc = (b[i] + conj(b[j])) * cp(0.5, 0);
        dd = (b[i] - conj(b[j])) * cp(0, -0.5);
        Da[j] = da * dc; Db[j] = da * dd; Dc[j] = db * dc; Dd[j] = db * dd; //顺便区间反转,方便等会直接用DFT代替IDFT 
    }
    for(int i = 0;i < L; ++i) a[i] = Da[i] + Db[i] * cp(0, 1);
    for(int i = 0;i < L; ++i) b[i] = Dc[i] + Dd[i] * cp(0, 1);
    FFT(a); FFT(b);
    for(int i = 0;i < L; ++i) {
        int da = (LL) (a[i].r / L + 0.5) % P; //直接取实部和虚部 
        int db = (LL) (a[i].i / L + 0.5) % P;
        int dc = (LL) (b[i].r / L + 0.5) % P;
        int dd = (LL) (b[i].i / L + 0.5) % P;
        C[i] = (da + ((LL)(db + dc) << 15) + ((LL)dd << 30)) % P; 
    }
}
int main() {
    n = read(); m = read(); P = read();
    for(int i = 0;i <= n; ++i) a[i] = read();
    for(int j = 0;j <= m; ++j) b[j] = read();
    Pre(); Mul(a, b, c); 
    for(int i = 0;i <= n + m; ++i) printf("%d ", (c[i] + P) % P); puts("");
    return 0;
}

猜你喜欢

转载自blog.csdn.net/lvzelong2014/article/details/80156989