拆系数FFT(任意模数FFT)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/oi_Konnyaku/article/details/84990404

拆系数FFT

对于任意模数 m o d mod
m = m o d m=\sqrt {mod}
把多项式 A ( x ) A(x) B ( x ) B(x) 的系数都拆成 a × m + b a\times m+b 的形式,时 a , b a, b 都小于 m m
提出,那么一个多项式就可以拆成两个多项式的加法
一个是 a m a*m 的,一个是 b b
直接乘法分配律, a a aa 一遍, a b ab 一遍, b a ba b b bb 一遍,四遍 F F T FFT
乘出来不会超过取模范围
然后合并直接
( a × m + b ) ( c × m + d ) = a × c × m 2 + ( a × c + b × d ) m + b × d (a\times m+b)(c\times m+d)=a\times c\times m^2+(a\times c+b\times d)m+b\times d
这样子要进行 7 7 D F T DFT

如果研究一下 m y y myy 2016 2016 年的集训队论文就会发现有 2 2 次 或者 1.5 1.5 D F T DFT F F T FFT 算法
2次的够了吧
m y y myy 巧妙的运用了复数的虚部,优化了算法
具体来说


C ( x ) = A ( x ) + i B ( x ) C(x)=A(x)+iB(x)
D ( x ) = A ( x ) i B ( x ) D(x)=A(x)-iB(x)

假设
c ( w n k ) c(w_n^k) 表示将 C ( x ) C(x) D F T DFT 后的点值
d ( w n k ) d(w_n^k) 表示将 D ( x ) D(x) D F T DFT 后的点值
w w n n 次单位复数根
c o n j ( x ) conj(x) 表示 x x 的共轭复数
那么
c ( w 2 n k ) = A ( w 2 n k ) + i B ( w 2 n k ) = j = 0 2 n 1 A j w 2 n j k + i B j w 2 n j k c(w_{2n}^{k})=A(w_{2n}^{k})+iB(w_{2n}^{k})=\sum_{j=0}^{2n-1}A_jw_{2n}^{jk}+iB_jw_{2n}^{jk}
= j = 0 2 n 1 ( A j + i B j ) ( c o s π j k n + i s i n π j k n ) =\sum_{j=0}^{2n-1}(A_j+iB_j)(cos\frac{\pi jk}{n}+isin \frac{\pi jk}{n})
d ( w 2 n k ) = A ( w 2 n k ) i B ( w 2 n k ) = j = 0 2 n 1 ( A j i B j ) ( c o s π j k n + i s i n π j k n ) d(w_{2n}^{k})=A(w_{2n}^{k})-iB(w_{2n}^{k})=\sum_{j=0}^{2n-1}(A_j-iB_j)(cos\frac{\pi jk}{n}+isin \frac{\pi jk}{n})
x = π j k n x=\frac{\pi jk}{n}
那么
d ( w 2 n k ) = j = 0 2 n 1 ( A j c o s x + B j s i n x ) + i ( A j s i n x B j c o s x ) d(w_{2n}^{k})=\sum_{j=0}^{2n-1}(A_jcosx+B_jsinx)+i(A_jsinx-B_jcosx)
= c o n j ( j = 0 2 n 1 ( A j c o s x + B j s i n x ) i ( A j s i n x B j c o s x ) ) =conj(\sum_{j=0}^{2n-1}(A_jcosx+B_jsinx)-i(A_jsinx-B_jcosx))
= c o n j ( j = 0 2 n 1 ( A j c o s ( x ) B j s i n ( x ) ) + i ( A j s i n ( x ) + B j c o s ( x ) ) ) =conj(\sum_{j=0}^{2n-1}(A_jcos(-x)-B_jsin(-x))+i(A_jsin(-x)+B_jcos(-x)))
= c o n j ( j = 0 2 n 1 ( A j + i B j ) ( c o s ( x ) + i s i n ( x ) ) ) =conj(\sum_{j=0}^{2n-1}(A_j+iB_j)(cos(-x)+isin(-x)))
= c o n j ( j = 0 2 n 1 ( A j + i B j ) ( c o s ( 2 π x ) + i s i n ( 2 π x ) ) ) =conj(\sum_{j=0}^{2n-1}(A_j+iB_j)(cos(2\pi-x)+isin(2\pi-x)))
= c o n j ( c ( w 2 n 2 n k ) ) =conj(c(w_{2n}^{2n-k}))

也就是说
只需要一次 D F T DFT 求出 c c 就可以求出 d d
那么
D F T ( A ( k ) ) = c ( w k ) + d ( w k ) 2 DFT(A(k))=\frac{c(w^{k})+d(w^{k})}{2}
D F T ( B ( k ) ) = c ( w k ) d ( w k ) 2 i DFT(B(k))=\frac{c(w^{k})-d(w^{k})}{2i}
再用一次 D F T 1 DFT^{-1} 还原出多项式,多项式乘法只要两次 D F T DFT

考虑拆系数 F F T FFT
D F T DFT 两两合并,从 4 4 次变成 2 2
D F T 1 DFT^{-1} 合并其中一个,从 3 3 次变成 2 2
一共 4 4

# include <bits/stdc++.h>
using namespace std;
typedef long long ll;

namespace IO {
    const int maxn((1 << 21) + 1);

    char ibuf[maxn], *iS, *iT, obuf[maxn], *oS = obuf, *oT = obuf + maxn - 1, c, st[65];
    int f, tp;
    
    char Getc() {
        return (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, maxn, stdin), (iS == iT ? EOF : *iS++)) : *iS++);
    }

    void Flush() {
        fwrite(obuf, 1, oS - obuf, stdout);
        oS = obuf;
    }

    void Putc(char x) {
        *oS++ = x;
        if (oS == oT) Flush();
    }
    
    template <class Int> void In(Int &x) {
        for (f = 1, c = Getc(); c < '0' || c > '9'; c = Getc()) f = c == '-' ? -1 : 1;
        for (x = 0; c <= '9' && c >= '0'; c = Getc()) x = (x << 3) + (x << 1) + (c ^ 48);
        x *= f;
    }
    
    template <class Int> void Out(Int x) {
        if (!x) Putc('0');
        if (x < 0) Putc('-'), x = -x;
        while (x) st[++tp] = x % 10 + '0', x /= 10;
        while (tp) Putc(st[tp--]);
    }
}

using IO :: In;
using IO :: Out;
using IO :: Putc;
using IO :: Flush;

const int maxn(1 << 18);
const double pi(acos(-1));

struct Complex {
    double a, b;

    inline Complex() {
        a = b = 0;
    }

    inline Complex(double _a, double _b) {
        a = _a, b = _b;
    }

    inline Complex operator +(Complex x) const {
        return Complex(a + x.a, b + x.b);
    }

    inline Complex operator -(Complex x) const {
        return Complex(a - x.a, b - x.b);
    }

    inline Complex operator *(Complex x) const {
        return Complex(a * x.a - b * x.b, a * x.b + b * x.a);
    }

    inline Complex Conj() {
        return Complex(a, -b);
    }
};

Complex a[maxn], b[maxn], w[maxn], a1[maxn], a2[maxn];
int r[maxn], l, deg, g[maxn], h[maxn], mod;

inline void FFT(Complex *p, int opt) {
    register int i, j, k, t;
    register Complex wn, x, y;
    for (i = 0; i < deg; ++i) if (r[i] < i) swap(p[r[i]], p[i]);
    for (i = 1; i < deg; i <<= 1)
        for(t = i << 1, j = 0; j < deg; j += t)
            for (k = 0; k < i; ++k) {
                wn = w[deg / i * k];
                if (opt == -1) wn.b *= -1;
                x = p[j + k], y = wn * p[i + j + k];
                p[j + k] = x + y, p[i + j + k] = x - y;
            }
}

inline void Mul(int n, int *p, int *q, int *f) {
    register int i, k, v1, v2, v3;
    register Complex ca, cb, da1, da2, db1, db2;
    for (deg = 1, l = 0; deg < n; deg <<= 1) ++l;
    for (i = 0; i < deg; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    for (i = 0; i < deg; ++i) w[i] = Complex(cos(pi * i / deg), sin(pi * i / deg));
    for (i = 0; i < n; ++i) a[i] = Complex(p[i] & 32767, p[i] >> 15), b[i] = Complex(q[i] & 32767, q[i] >> 15);
    for (FFT(a, 1), FFT(b, 1), i = 0; i < deg; ++i) {
        k = (deg - i) & (deg - 1), ca = a[k].Conj(), cb = b[k].Conj();
        da1 = (ca + a[i]) * Complex(0.5, 0), da2 = (a[i] - ca) * Complex(0, -0.5);
        db1 = (cb + b[i]) * Complex(0.5, 0), db2 = (b[i] - cb) * Complex(0, -0.5);
        a1[i] = da1 * db1 + (da1 * db2 + da2 * db1) * Complex(0, 1), a2[i] = da2 * db2;
    }
    for (FFT(a1, -1), FFT(a2, -1), i = 0; i < deg; ++i) {
        v1 = (ll)(a1[i].a / deg + 0.5) % mod, v2 = (ll)(a1[i].b / deg + 0.5) % mod;
        v3 = (ll)(a2[i].a / deg + 0.5) % mod, f[i] = (((ll)v3 << 30) + ((ll)v2 << 15) + v1) % mod;
        if (f[i] < 0) f[i] += mod;
    }
}

int main() {
    register int len, i, n, m;
    In(n), In(m), In(mod), ++n, ++m;
    for (i = 0; i < n; ++i) In(h[i]), h[i] %= mod;
    for (i = 0; i < m; ++i) In(g[i]), g[i] %= mod;
    for (len = 1, n += m - 1; len < n; len <<= 1);
    for (Mul(len, h, g, g), i = 0; i < n; ++i) Out(g[i]), Putc(' ');
    return Flush(), 0;
}

猜你喜欢

转载自blog.csdn.net/oi_Konnyaku/article/details/84990404
FFT