【模板】A*B Problem升级版(FFT快速傅里叶)

题目描述

给出两个 $n$ 位10进制数x和y,求x*y(详见 洛谷P1919

分析

假设已经学会了FFT/NTT。

高精度乘法只是多项式乘法的特殊情况,相当于$x=10$ 时。

例如n=3,求123*111

$$123 = x^2 + 2x + 3$$

$$111 = x^2 + x +1$$

$$\begin{aligned}123 * 111 &= (x^2 + 2x + 3)(x^2 + x +1)\\  &= x^4 + 3x^3 + 6x^2 + 5x + 3\\  &= 13653\end{aligned}$$

代码:

#include<bits/stdc++.h>
#define rg register
using namespace std;

typedef long long ll;
const int mod=998244353,g=3;
const int maxn = 6e4 + 10;

inline int qpow(int x,int k)
{
    int ans=1;
    while(k)
    {
        if(k&1)
            ans=(ll)ans*x%mod;
        x=(ll)x*x%mod,k>>=1;
    }
    return ans;
}

inline int module(int x,int y)
{
    x+=y;
    if(x>=mod)
        x-=mod;
    return x;
}

int rev[4*maxn];
inline void NTT(int*t,int lim,int type)
{
    for(rg int i=0;i<lim;++i)
        if(i<rev[i])
            swap(t[i],t[rev[i]]);
    for(rg int i=1;i<lim;i<<=1)
    {
        int gn=qpow(g,(mod-1)/(i<<1));
        if(type==-1)
            gn=qpow(gn,mod-2);
        for(rg int j=0;j<lim;j+=(i<<1))
        {
            int gi=1;
            for(rg int k=0;k<i;++k,gi=(ll)gi*gn%mod)
            {
                int x=t[j+k],y=(ll)gi*t[j+i+k]%mod;
                t[j+k]=module(x,y);
                t[j+i+k]=module(x,mod-y);
            }
        }
    }
    if(type==-1)
    {
        int inv=qpow(lim,mod-2);
        for(rg int i=0;i<lim;++i)
            t[i]=(ll)t[i]*inv%mod;
    }
}

int X[4*maxn],Y[4*maxn];
inline void mul(int*x, int*y, int n, int m)
{
    memset(X,0,sizeof(X));
    memset(Y,0,sizeof(Y));
    int lim = 1, L = 0;  //L=0必须写,局部变量默认值很可能不是0
    while(lim <= n + m) lim <<= 1, L++;   //lim为大于(n+m)的2的幂,所以最多需要4倍空间
    for(int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
    for(rg int i=0;i<lim;++i) X[i]=x[i],Y[i]=y[i];
    NTT(X,lim,1);
    NTT(Y,lim,1);
    for(rg int i=0;i<lim;++i) X[i]=(ll)X[i]*Y[i]%mod;
    NTT(X,lim,-1);
    for(rg int i=0;i<lim;++i) x[i]=X[i];
}


int n;
int a[4*maxn], b[4*maxn];
char s[maxn];

int main()
{
    scanf("%d", &n);
    scanf("%s", s);
    for(int i = 0;i < n;i++) a[i] = s[n-1-i] - '0';
    scanf("%s", s);
    for(int i = 0;i < n;i++)  b[i] = s[n-1-i] - '0';
    mul(a, b, n, n);

//    for(int i = 0;i < 2*n;i++)  printf("%d ", a[i]);
//    printf("\n");

    int tmp = 0;    //进位
    for(int i = 0;i < 2*n;i++)  //
    {
        a[i] = a[i] + tmp;
        tmp = a[i] / 10;
        a[i] = a[i] % 10;

    }

//    for(int i = 0;i < 2*n;i++)  printf("%d ", a[i]);
//    printf("\n");

    bool flag = true;
    for(int i = 2*n;i >= 0;i--)  //逆序输出,去掉前导零
    {
        if(flag && a[i] == 0)  continue;
         printf("%d", a[i]);
         flag = false;
    }

    return 0;
}

猜你喜欢

转载自www.cnblogs.com/lfri/p/11242183.html
ab