【模板】任意模数NTT(中国剩余定理版,O(1)long long乘)

任意模数的NTT:题目给定的模数 MOD 不是 a·2^k+1 的形式,或者其中的 2^k 小于所需的变换长度,因而无法直接在模 MOD 下做NTT。

如果假定NTT作用长度为len,系数的值不大于x,则相乘后系数不大于len*x^2.

如果我们取合适的多个模数p[](他们有相同的原根),使p_1p_2p_3>len*x^2,同时我们得到分别以p_1,p_2,p_3为模的NTT作用系数C[1][],C[2][],C[3][],我们可以得到实际系数x满足:

x[i] \equiv C[1][i] \pmod{p_1},\quad x[i] \equiv C[2][i] \pmod{p_2},\quad x[i] \equiv C[3][i] \pmod{p_3}。由中国剩余定理,通解为 x[i] = X_0 + k p_1 p_2 p_3;而 0 \le x[i] \le len \cdot x^2 < p_1 p_2 p_3,所以 x[i] = X_0。

但这里由于p_i的选取,p_1p_2p_3会爆long\ long,故可以先求解x[i] \equiv x_0 (mod\ p_1p_2),\ x[i]=x_0+k_1p_1p_2=C[3][i]+k_3p_3

k_1p_1p_2 \equiv C[3][i]-x_0(mod\ p_3),求出k_1后即可得到X_0

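这里有个神奇的O(1) long long乘法,根据的是 A \% B = A - B \cdot \lfloor \frac{A}{B} \rfloor:先用long double估算商 \lfloor \frac{a \cdot b}{mod} \rfloor(误差至多为1),a \cdot b 与 商 \cdot mod 虽然都会溢出,但溢出只相当于对 2^{64} 取模,而两者真实的差落在 (-mod, 2 \cdot mod) 之内,所以溢出后相减的结果仍然是正确的,最后再把结果调整到 [0, mod) 即可。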

// Returns (a * b) mod `mod` in O(1), even when a * b does not fit in 64 bits.
//
// Uses a*b mod m == a*b - floor(a*b / m) * m. The quotient is estimated in
// long double, and both products are formed in unsigned 64-bit arithmetic,
// where overflow wraps modulo 2^64 (well defined, unlike signed overflow).
// The true difference is small, so the wrapped difference is still exact;
// the final adjustment absorbs an off-by-one quotient estimate.
//
// Preconditions: mod > 0. Negative operands are normalised into [0, mod).
// NOTE(review): the quotient estimate is only accurate enough for moduli
// near 2^59 when long double is wider than double (e.g. x87 extended
// precision) - confirm on the target platform.
LL multi(LL a, LL b, LL mod){
  using ULL = unsigned long long;
  a %= mod, b %= mod;
  if(a < 0) a += mod;
  if(b < 0) b += mod;
  // Estimated quotient floor(a * b / mod); the small bias guards against the
  // floating-point product landing just below an exact integer.
  const ULL quot = static_cast<ULL>((long double)a / mod * b + 1e-3);
  // Both products wrap modulo 2^64; their difference is the true remainder,
  // possibly shifted by one multiple of mod.
  LL rem = static_cast<LL>(static_cast<ULL>(a) * static_cast<ULL>(b) - quot * static_cast<ULL>(mod));
  rem %= mod;
  if(rem < 0) rem += mod;
  return rem;
}
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;

// 64-bit signed integer used for every modular computation in this file.
using LL = long long;

// Capacity of each coefficient buffer below.
const int MAXN = 4e5 + 5;

// The three NTT moduli, indexed 1..3 (slot 0 is an unused placeholder).
const LL p[] = {0, 469762049ll, 998244353ll, 1004535809ll};
// Generator passed to qpow() to derive the roots of unity for every modulus.
const LL g = 3;
// Product of the first two moduli, p[1] * p[2]; it still fits in 64 bits.
const LL MX = 469762049ll * 998244353ll;

// Degrees of the two input polynomials and the modulus requested for output.
LL N, M, MOD;
// Input coefficients: A[0..N] and B[0..M].
LL A[MAXN], B[MAXN];
// C[I] holds the product's coefficients modulo p[I] (row 0 is unused).
LL C[4][MAXN];
// Scratch buffers that getC() transforms in place.
LL F[MAXN], G[MAXN];
// Final coefficients of the product, reduced modulo MOD.
LL ans[MAXN];

// Forward declarations; the definitions follow main().
LL qpow(LL, int, LL);
void DNT(int len, LL *a, int type, LL mod);
void getC(int, int);
LL multi(LL, LL, LL);

int main(){
  ios::sync_with_stdio(false);
  cin >> N >> M >> MOD;
  int i;
  for(i = 0; i <= N; i++) cin >> A[i];
  for(i = 0; i <= M; i++) cin >> B[i];
  int len = 1;
  while(len <= M + N) len <<= 1;
  getC(1, len), getC(2, len), getC(3, len);
  for(i = 0; i <= N + M; i++){
    LL x, k1;
    x = (multi(C[1][i] * p[2] % MX, qpow(p[2] % p[1], p[1] - 2, p[1]), MX) +
         multi(C[2][i] * p[1] % MX, qpow(p[1] % p[2], p[2] - 2, p[2]), MX)) % MX;
    k1 = (multi((C[3][i] % p[3] - x % p[3] + p[3]) % p[3], qpow(MX % p[3], p[3] - 2, p[3]), p[3]));
    ans[i] = ((k1 % MOD) * (MX % MOD) + x % MOD) % MOD;
    cout << ans[i] << " ";
  }
  return 0;
}

void getC(int I, int len){
  int i;
  memset(F, 0, sizeof(F)), memset(G, 0, sizeof(G));
  for(i = 0; i <= N; i++) F[i] = A[i];
  for(i = 0; i <= M; i++) G[i] = B[i];
  DNT(len, F, 1, p[I]), DNT(len, G, 1, p[I]);
  for(i = 0; i <= len; i++)
    C[I][i] = F[i] * G[i] % p[I];
  DNT(len, C[I], -1, p[I]);
}

void bit_reverse(int, LL *);

// In-place iterative number-theoretic transform of a[0..len-1] modulo `mod`.
//
// len  - number of points; callers in this file pass a power of two.
// a    - buffer of at least len coefficients, expected to be non-negative.
// type - 1 for the forward transform, -1 for the inverse transform (which
//        also divides every element by len).
// mod  - modulus; the roots of unity are derived from the global g via qpow().
void DNT(int len, LL *a, int type, LL mod){
  bit_reverse(len, a);

  for(int l = 2; l <= len; l <<= 1){
    const int mid = l >> 1;
    // Root of unity for blocks of length l; inverted for the inverse transform.
    LL wn = qpow(g, (mod - 1) / l, mod);
    if(type == -1) wn = qpow(wn, mod - 2, mod);

    for(int i = 0; i < len; i += l){
      LL w = 1;
      for(int j = 0; j < mid; j++, w = w * wn % mod){
        const LL x = a[i + j];
        const LL y = w * a[i + j + mid] % mod;
        a[i + j] = (x + y) % mod;
        // Adding mod keeps the difference non-negative before reducing.
        a[i + j + mid] = (x - y + mod) % mod;
      }
    }
  }

  if(type == -1){
    // Multiply by the modular inverse of len. Only the len transformed
    // elements are scaled; a[len] is outside the transform.
    const LL inv = qpow(len, mod - 2, mod);
    for(int i = 0; i < len; i++) a[i] = a[i] * inv % mod;
  }
}

// Reorders a[0..len-1] by walking a counter `rev` that is incremented in
// bit-reversed order alongside the plain index. Each pair (idx, rev) is
// swapped once, when idx is the larger of the two.
void bit_reverse(int len, LL *a){
  int rev = 0;
  for(int idx = 0; idx < len; ++idx){
    if(rev < idx) std::swap(a[idx], a[rev]);

    // Add one to `rev` starting from its top bit: clear leading ones,
    // then set the first zero bit that is reached.
    int bit = len >> 1;
    while(rev & bit){
      rev ^= bit;
      bit >>= 1;
    }
    rev ^= bit;
  }
}

// Returns x^n modulo `mod` by binary exponentiation.
//
// x   - base; any value is accepted and reduced into [0, mod) first, so a
//       large or negative base cannot overflow the squaring step.
// n   - exponent; a negative exponent is treated like 0.
// mod - modulus, must be positive and small enough that (mod - 1)^2 fits in
//       a signed 64-bit integer (all moduli used in this file do).
LL qpow(LL x, int n, LL mod){
  x %= mod;
  if(x < 0) x += mod;
  LL res = 1 % mod;  // yields 0 when mod == 1
  for(; n > 0; n >>= 1){
    if(n & 1) res = res * x % mod;
    x = x * x % mod;
  }
  return res;
}

// Returns (a * b) mod `mod` in O(1), even when a * b does not fit in 64 bits.
//
// Uses a*b mod m == a*b - floor(a*b / m) * m. The quotient is estimated in
// long double, and both products are formed in unsigned 64-bit arithmetic,
// where overflow wraps modulo 2^64 (well defined, unlike signed overflow).
// The true difference is small, so the wrapped difference is still exact;
// the final adjustment absorbs an off-by-one quotient estimate.
//
// Preconditions: mod > 0. Negative operands are normalised into [0, mod).
// NOTE(review): the quotient estimate is only accurate enough for moduli
// near 2^59 when long double is wider than double (e.g. x87 extended
// precision) - confirm on the target platform.
LL multi(LL a, LL b, LL mod){
  using ULL = unsigned long long;
  a %= mod, b %= mod;
  if(a < 0) a += mod;
  if(b < 0) b += mod;
  // Estimated quotient floor(a * b / mod); the small bias guards against the
  // floating-point product landing just below an exact integer.
  const ULL quot = static_cast<ULL>((long double)a / mod * b + 1e-3);
  // Both products wrap modulo 2^64; their difference is the true remainder,
  // possibly shifted by one multiple of mod.
  LL rem = static_cast<LL>(static_cast<ULL>(a) * static_cast<ULL>(b) - quot * static_cast<ULL>(mod));
  rem %= mod;
  if(rem < 0) rem += mod;
  return rem;
}

猜你喜欢

转载自blog.csdn.net/Hardict/article/details/82718218