HDU 1402(FFT,NTT)

题意:大数相乘。

主要是存板子的:一个 FFT,一个 NTT(快速数论变换)。其中 FFT 基于浮点运算,可能存在精度问题,有时会被卡精度(比如题目要求对结果取模时),这时我们一般改用 NTT。

//FFT
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
using namespace std;
typedef long long ll;

const double PI=acos(-1.0);

// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
// The arithmetic operators are const member functions so they can also be applied
// to const operands (the original non-const versions could not).
struct complex{
    double r,i;
    complex(double _r=0,double _i=0):r(_r),i(_i){}
    complex operator +(const complex &b) const
    {
        return complex(r+b.r,i+b.i);
    }
    complex operator -(const complex &b) const
    {
        return complex(r-b.r,i-b.i);
    }
    // (r + i*I) * (b.r + b.i*I)
    complex operator *(const complex &b) const
    {
        return complex(r*b.r-i*b.i,r*b.i+i*b.r);
    }
};

// Reorders y[0..len-1] into bit-reversed index order, the input layout that
// fft() establishes by calling this first.
// j tracks the bit-reversed counterpart of i: each step performs a "+1" on j
// that propagates from the most significant bit downwards.
void change(complex y[],int len)
{
    const int half=len/2;
    int j=half;
    for(int i=1;i<len-1;++i)
    {
        // Swap each pair only once (when i is the smaller index).
        if(i<j)
            swap(y[i],y[j]);
        // Clear the leading run of set bits of j, then set the first clear one.
        int bit=half;
        for(;j>=bit;bit/=2)
            j-=bit;
        j+=bit;
    }
}

// In-place iterative radix-2 FFT over y[0..len-1].
//   y   : the data; overwritten with its transform.
//   len : number of points; expected to be a power of two (main builds it by doubling).
//   on  : 1 for the forward transform, -1 for the inverse transform.
// The inverse transform is normalised by 1/len on BOTH the real and imaginary
// parts, so fft(fft(y,len,1),len,-1) reproduces the original complex values.
void fft(complex y[],int len,int on)
{
    change(y,len); // bit-reversed order lets the butterflies run in place
    for(int h=2;h<=len;h<<=1)
    {
        // Principal h-th root of unity; the sign of the angle selects the direction.
        const double ang=-on*2*PI/h;
        const complex wn(cos(ang),sin(ang));
        for(int j=0;j<len;j+=h)
        {
            complex w(1,0);
            for(int k=j;k<j+h/2;k++)
            {
                complex u=y[k];
                complex t=w*y[k+h/2];
                y[k]=u+t;
                y[k+h/2]=u-t;
                w=w*wn;
            }
        }
    }
    if(on==-1)
    {
        for(int i=0;i<len;i++)
        {
            y[i].r/=len;
            y[i].i/=len;
        }
    }
}
const int maxn=200010;
// Each operand buffer holds maxn/2 bytes, i.e. at most maxn/2-1 = 100004 digits.
// For two operands of that length main picks a transform length of 2^18 = 262144,
// which is larger than maxn, so the FFT buffers are sized from this bound instead.
const int maxfft=1<<18;
char a[maxn/2],b[maxn/2];
complex x1[maxfft],x2[maxfft];
int ans[maxfft+1]; // one spare slot: a carry loop may write ans[len]
// Reads pairs of non-negative decimal integers (whitespace separated) until EOF
// and prints each product on its own line, computed as an FFT convolution of the
// two digit sequences.
int main()
{
    // Stop at EOF, and also if the second operand of a pair is missing
    // (otherwise b would silently keep the previous operand).
    while(scanf("%s",a)!=EOF && scanf("%s",b)==1)
    {
        const int len1=static_cast<int>(strlen(a));
        const int len2=static_cast<int>(strlen(b));

        // The product has at most len1+len2 digits, so a power-of-two length of
        // at least len1+len2 keeps the cyclic convolution from wrapping around.
        int len=1;
        while(len<len1+len2) len<<=1;

        // Digits little-endian (x[k] is the digit of weight 10^k), zero-padded to len.
        for(int i=0;i<len;i++)
        {
            x1[i]=complex(i<len1 ? a[len1-i-1]-'0' : 0,0);
            x2[i]=complex(i<len2 ? b[len2-i-1]-'0' : 0,0);
        }

        fft(x1,len,1);
        fft(x2,len,1);
        for(int i=0;i<len;i++)
            x1[i]=x1[i]*x2[i];
        fft(x1,len,-1);

        // Round each convolution term to the nearest integer (the terms are non-negative).
        for(int i=0;i<len;i++)
            ans[i]=static_cast<int>(x1[i].r+0.5);

        // Propagate carries up to the highest possible digit position
        // top = len1+len2-1 (< len). The product is below 10^(len1+len2),
        // so nothing carries out of position top.
        int top=len1+len2-1;
        for(int i=0;i<top;i++)
        {
            ans[i+1]+=ans[i]/10;
            ans[i]%=10;
        }

        // Skip leading zeros, keeping a single 0 for a zero product.
        while(top>0 && ans[top]==0) top--;
        for(int i=top;i>=0;i--)
            putchar('0'+ans[i]);
        putchar('\n');
    }
    return 0;
}
//NTT
#include<bits/stdc++.h>

using namespace std;
typedef long long ll;

const int maxn = 2e5 + 50;
const ll mod = 998244353;

// Modular exponentiation: returns a^b mod m using binary exponentiation.
//   a : base; any value, including negative or >= m (it is reduced first).
//   b : exponent; must be >= 0.
//   m : modulus; defaults to the NTT modulus. Requires 1 <= m < 2^31 so that every
//       product of two reduced values fits in a signed 64-bit integer.
// Reducing a up front prevents the a*a product from overflowing for huge bases.
inline ll qpow(ll a, ll b, ll m = mod)
{
	a %= m;
	if (a < 0)
		a += m;
	ll sum = 1 % m;
	while (b)
	{
		if (b & 1)
			sum = sum * a % m;
		b >>= 1;
		a = a * a % m;
	}
	return sum;
}

// Modular inverse of a modulo _mod via Fermat's little theorem: a^(_mod-2) mod _mod.
// Valid only when _mod is prime and a is not a multiple of _mod.
// The exponentiation is now actually performed modulo _mod (previously the
// global mod was used regardless of the argument).
inline ll Inv(ll a, ll _mod)
{
	return qpow(a, _mod - 2, _mod);
}


struct NTT
{
	int rev[maxn], dig[105];
	int N, L;
	ll g;
	void init_rev(int n)
	{
		//初始化原根
		g = 3;
		for (N = 1, L = 0; N <= n; N <<= 1, L++);
		memset(dig,0,sizeof(int)*(L+1));
		for (int i = 0; i < N; i++)
		{
			rev[i] = 0;
			int len = 0;
			for (int t = i; t; t >>= 1)
				dig[len++] = t & 1;
			for (int j = 0; j < L; j++)
				rev[i] = (rev[i] << 1) | dig[j];
		}
	}

	void DFT(vector<ll>&a , int flag)
	{
		for (int i = 0; i < N; i++)
			if (i < rev[i])
				swap(a[i], a[rev[i]]);

		for (int l = 1; l < N; l <<= 1)
		{
			ll wn;
			if (flag == 1)
				wn = qpow(g, (mod - 1) / (2*l));
			else
				wn = qpow(g, mod - 1 - (mod - 1) / (2*l));
			for (int k = 0; k < N; k += l*2)
			{
				ll w = 1;
				ll x,y;
				for (int j = k; j < k + l; j++)
				{
					x = a[j];
					y = a[j+l] * w % mod;
					a[j] = (x + y) % mod;
					a[j + l] = (x - y + mod) % mod;
					w = w * wn % mod;
				}
			}
		}
		if (flag == -1)
		{
			ll x = Inv(N, mod);
			for (int i = 0; i < N; i++)
				a[i] = a[i] * x % mod;
		}
	}

	void mul(vector<ll>& a,vector<ll>& b,int m)
	{
		init_rev(m);
		a.resize(N);
		b.resize(N);
		DFT(a, 1);
		DFT(b, 1);
		for (int i = 0; i < N; i++)
			a[i] = a[i] * b[i]%mod;
		DFT(a, -1);
		int len = N;
		while(a[len]==0) len--;
		a.resize(len+1);
	}
} ntt;

// Polynomial (digit) vectors; main uses v[1] and v[2] as the two operands,
// and v[1] receives the product from ntt.mul.
vector<ll> v[maxn];
// Input buffer for one decimal operand: up to 100005 digits plus the terminator.
char s[100006];
// Scratch buffer for the product's decimal digits, least significant first.
ll ans[200505];
int main()
{
	while(scanf("%s",s)!=EOF)
	{
		v[1].clear();
		v[2].clear();
		
		int len=strlen(s);
		v[1].resize(len);
		for(int j=0; j<len; j++)
		{
			v[1][len-1-j] = s[j]-'0';
		}
		scanf("%s",s);
		len=strlen(s);
		v[2].resize(len);
		for(int j=0; j<len; j++)
		{
			v[2][len-1-j] = s[j]-'0';
		}
		
		ntt.mul(v[1],v[2],v[1].size()+v[2].size());
		memset(ans,0,sizeof(ans));
		for(int i=0; i<v[1].size()*2; i++)
		{
			if(v[1].size()>i)
            ans[i]+=v[1][i];
            if(ans[i]>=10)
            {
                ans[i+1]+=ans[i]/10;
                ans[i]%=10;
            }
		}
		
		int e=0;
        for(int i=v[1].size()*2-1;i>=0;i--)
        {
            if(ans[i])
            {
                e=i;
                break;
            }
        }
        for(int i=e;i>=0;i--)
        {
            printf("%lld",ans[i]);
        }
        printf("\n");
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_37632935/article/details/81561672