任意模数NTT学习笔记

这两天有点颓,所以东西学的也很慢。。。这个一眼就能推出来的活生生卡了我两天。。

说几个细节:

柿子:

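\[f \cdot g = \left(M\left\lfloor \frac{f}{M} \right\rfloor + f \bmod M\right)\cdot\left(M\left\lfloor \frac{g}{M} \right\rfloor + g \bmod M\right) \]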

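\(M\)通常取\(32768 = 2^{15}\)(约为 \(\sqrt{p}\),这样拆出来的每一部分都小于 \(M\))。把上一步中的几个函数记为 \(a = \lfloor f/M \rfloor\),\(b = f \bmod M\),\(c = \lfloor g/M \rfloor\),\(d = g \bmod M\)(对多项式的每个系数逐项拆分),答案就是: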

  • \(M^2 \cdot (a * c) + M \cdot (a * d + b * c) + b * d\),其中 \(*\) 表示多项式卷积

这四个都是卷积:对 \(a,b,c,d\) 各做一次 \(FFT\),逐点相乘后再做 4 次逆变换,四舍五入取整后在模 \(p\) 意义下合并即可。

  • 处处小心取模与溢出。有效方法如下:

    • #define int long long
      #define double long double
      int add (int x, int y) {x %= p, y %= p; return (((x + y) % p) + p) % p;}
      int mul (int x, int y) {x %= p, y %= p; return (((x * y) % p) + p) % p;}
  • 数组记得开大:FFT 的长度是大于 \(n+m\) 的最小的 2 的幂,所有数组都至少要开到这么大,一般开到 \(N<<2\) 左右。

#include <bits/stdc++.h>
using namespace std;

// Competitive-programming shortcuts used throughout this file: every `int`
// below is really a 64-bit `long long`, and every `double` is a
// `long double` (the extra mantissa bits keep FFT round-off small enough to
// recover exact integer products).
#define int long long
#define double long double

// Capacity of every global array; must be at least the padded FFT length
// `lim` that main computes.
const int N = 600010;
// pi at full long-double precision. `acos (-1)` with an integer argument
// selects the double-returning overload, so pi (and every twiddle factor
// derived from it) would only carry ~53 bits, throwing away the accuracy
// the `long double` switch above is meant to provide.
const double pi = std::acos (-1.0L);

// n, m : degrees of the two input polynomials (main reads n + 1 and m + 1
//        coefficients); p : output modulus; lim : FFT length (power of two,
//        set in main); f, g : input coefficients; rev : bit-reversal table;
//        res : output coefficients modulo p.
int n, m, p, lim = 1, f[N], g[N], rev[N], res[N];

// Minimal complex number used by the hand-written FFT below.
// Both components are `double`, which this file widens to `long double`
// via a macro.
struct Complex {
    double x, y; // x: real part, y: imaginary part

    Complex (double re = 0, double im = 0) : x (re), y (im) {}

    // Component-wise sum.
    Complex operator + (Complex o) {
        const double re = x + o.x;
        const double im = y + o.y;
        return Complex (re, im);
    }

    // Component-wise difference.
    Complex operator - (Complex o) {
        const double re = x - o.x;
        const double im = y - o.y;
        return Complex (re, im);
    }

    // (x + iy)(o.x + i o.y) = (x o.x - y o.y) + i (x o.y + y o.x)
    Complex operator * (Complex o) {
        const double re = x * o.x - y * o.y;
        const double im = x * o.y + y * o.x;
        return Complex (re, im);
    }
};

// a, b: high / low halves of f's coefficients; c, d: high / low halves of g's.
// t1..t4: pointwise products a*c, a*d, b*c, b*d (then inverse-transformed).
Complex a[N], b[N], c[N], d[N], t1[N], t2[N], t3[N], t4[N];


// Splitting base (2^15): main writes each coefficient v as
// v = M * (v / M) + v % M and convolves the two halves separately.
const int M = 32768;

// (x + y) mod p, normalised into [0, p) even for negative operands.
// Returns long long: the result can be as large as p - 1, which would be
// truncated by a 32-bit return type once p exceeds INT_MAX.
long long add (long long x, long long y) {
    x %= p, y %= p;
    return ((x + y) % p + p) % p;
}

// (x * y) mod p, normalised into [0, p).
// The product of the two reduced operands must fit in long long, so this
// requires p <= ~3.03e9 (about sqrt(LLONG_MAX)).
long long mul (long long x, long long y) {
    x %= p, y %= p;
    return ((x * y) % p + p) % p;
}

// In-place iterative radix-2 FFT over A[0 .. lim).
// type = 1 : forward transform (twiddles e^{+i pi k / half}).
// type = -1: inverse transform; afterwards only the real part is divided by
//            lim (main reads nothing but .x of the inverse results).
// Relies on the globals `lim` (power of two) and `rev` (bit-reversal table).
void fast_fast_tle (Complex *A, int type) {
    // Reorder into bit-reversed index order so the butterflies can run in place.
    for (int i = 0; i < lim; ++i)
        if (i < rev[i]) swap (A[i], A[rev[i]]);

    // `half` is the size of each of the two sub-transforms being merged.
    for (int half = 1; half < lim; half *= 2) {
        const double ang = pi / half;
        Complex step (cos (ang), type * sin (ang));
        for (int base = 0; base < lim; base += 2 * half) {
            Complex w (1, 0);
            for (int k = 0; k < half; ++k) {
                Complex u = A[base + k];
                Complex v = w * A[base + k + half];
                A[base + k] = u + v;
                A[base + k + half] = u - v;
                w = w * step;
            }
        }
    }

    if (type == -1)
        for (int i = 0; i < lim; ++i) A[i].x /= lim;
}

signed main () {
    cin >> n >> m >> p;
    for (int i = 0; i <= n; ++i) {
        cin >> f[i];
        a[i].x = f[i] / M;
        b[i].x = f[i] % M;
    }
    for (int i = 0; i <= m; ++i) {
        cin >> g[i];
        c[i].x = g[i] / M;
        d[i].x = g[i] % M;
    }
    while (lim <= n + m) lim <<= 1;
    for (int i = 0; i < lim; ++i) {
        rev[i] = (rev[i >> 1] >> 1) + (i & 1) * (lim / 2);
    }
    fast_fast_tle (a, 1), fast_fast_tle (b, 1); fast_fast_tle (c, 1), fast_fast_tle (d, 1);
    for (int i = 0; i < lim; ++i) {
        t1[i] = a[i] * c[i], t2[i] = a[i] * d[i], t3[i] = b[i] * c[i], t4[i] = b[i] * d[i];
    }
    fast_fast_tle (t1, -1), fast_fast_tle (t2, -1); fast_fast_tle (t3, -1), fast_fast_tle (t4, -1);
    for (int i = 0; i < lim; ++i) {
        res[i] = add (res[i], mul (mul (M, M), (int) (t1[i].x + 0.1)));
        res[i] = add (res[i], mul (mul (M, 1), (int) (t2[i].x + 0.1)));
        res[i] = add (res[i], mul (mul (M, 1), (int) (t3[i].x + 0.1)));
        res[i] = add (res[i], mul (mul (1, 1), (int) (t4[i].x + 0.1)));
    }
    for (int i = 0; i <= n + m; ++i) printf ("%lld ", res[i]);
}

猜你喜欢

转载自www.cnblogs.com/maomao9173/p/10610805.html