FFT/NTT板子

FFT

#include<bits/stdc++.h>
using namespace std;
// Size limits: s holds one decimal operand (at most 60001 digits plus the NUL),
// so the padded transform length lim is at most 2^17 = 131072, well below N.
const int N=240002;
const double pi=std::acos(-1.0);

// Minimal complex number: x is the real part, y the imaginary part.
struct C{
	double x,y;
};
C a[N],b[N];   // operand coefficients; main stores digits little-endian in .x

int lim=1;     // transform length, always a power of two
int i,n,l;     // i: loop index; n: highest coefficient index; l = log2(lim)
int r[N];      // bit-reversal permutation of [0, lim), filled by main
int c[N];      // product digits after rounding and carry propagation
char s[60002]; // input buffer for one decimal operand

// Complex arithmetic. Values are built with standard aggregate initialisation
// C{...}; the former (C){...} form is a GNU compound-literal extension that
// ISO C++ does not permit (rejected by -pedantic-errors and by MSVC).
C operator +(C a,C b){return C{a.x+b.x,a.y+b.y};}
C operator -(C a,C b){return C{a.x-b.x,a.y-b.y};}
C operator *(C a,C b){return C{a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};}
// In-place iterative radix-2 FFT over A[0..lim).
// opt ==  1: forward transform.
// opt == -1: transform with conjugate roots (inverse direction). The result is
//            NOT divided by lim; the caller performs that scaling.
// Requires the global r[] to hold the bit-reversal permutation of [0, lim).
// Roots are built with standard C{...} aggregate initialisation instead of the
// non-ISO compound literal (C){...}.
void fft(C *A,int opt){
	// Reorder into bit-reversed positions so the butterflies can work in place.
	for (int i=0;i<lim;i++)
		if (i<r[i]) swap(A[i],A[r[i]]);
	// Merge blocks of size mid into blocks of size 2*mid.
	for (int mid=1;mid<lim;mid<<=1){
		// Primitive (2*mid)-th root of unity e^{opt * i * pi / mid}.
		C wn{cos(pi/mid),opt*sin(pi/mid)};
		for (int R=mid<<1,j=0;j<lim;j+=R){
			C w{1,0};   // current twiddle factor wn^k
			for (int k=0;k<mid;k++,w=w*wn){
				C x=A[j|k],y=w*A[j|k|mid];
				A[j|k]=x+y;
				A[j|k|mid]=x-y;
			}
		}
	}
}
// Reads a digit count n and two n-digit decimal integers, then prints their
// product computed via FFT-based polynomial multiplication.
int main(){
	scanf("%d%s",&n,s);
	n--;   // n is now the highest coefficient index
	// Store digits little-endian: a[k].x is the coefficient of 10^k.
	for (i=0;i<=n;i++) a[n-i].x=s[i]^48;
	scanf("%s",s);
	for (i=0;i<=n;i++) b[n-i].x=s[i]^48;
	// Smallest power of two strictly greater than the product's top index 2n.
	for (;lim<=n<<1;lim<<=1,l++);
	// Bit-reversal permutation. For single-digit operands lim == 1 and l == 0;
	// shifting by l-1 == -1 would be undefined behaviour, so the high bit is
	// only contributed when l > 0.
	for (i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|(l>0?((i&1)<<(l-1)):0);
	fft(a,1);fft(b,1);
	// Pointwise product in the frequency domain = convolution of digit arrays.
	for (i=0;i<lim;i++) a[i]=a[i]*b[i];
	fft(a,-1);
	n<<=1;
	// Undo the unscaled inverse transform and round to the nearest integer.
	for (i=0;i<=n;i++) c[i]=(int)(a[i].x/lim+0.5);
	// Propagate carries; the final carry lands in c[n+1].
	for (i=0;i<=n;i++) c[i+1]+=c[i]/10,c[i]%=10;
	n++;
	// Strip leading zeros but always keep c[0], so a zero product prints "0"
	// instead of scanning below the start of c[].
	while (n>0 && !c[n]) n--;
	for (i=n;i>=0;i--) printf("%d",c[i]);
}

NTT

#include<bits/stdc++.h>
using namespace std;
// NTT parameters. M = 998244353 = 119 * 2^23 + 1 is prime, so Z_M contains
// 2^k-th roots of unity for every 2^k <= 2^23. g = 3 is a primitive root modulo
// M and gi = 332748118 is its inverse (3 * gi == M + 1).
constexpr int N = 240002;
constexpr int M = 998244353;
constexpr int g = 3;
constexpr int gi = 332748118;

int lim = 1;      // transform length, always a power of two
int i, n, l;      // loop index, highest coefficient index, log2(lim)
int inv;          // lim^{-1} mod M, computed by main
int r[N];         // bit-reversal permutation of [0, lim)
int c[N];         // product digits after carry propagation
int a[N], b[N];   // operand coefficients (residues mod M)
char s[60002];    // input buffer for one decimal operand

// (x + y) mod M, for x and y already reduced into [0, M).
inline int inc(int x, int y) {
	return (x + y >= M) ? x + y - M : x + y;
}

// (x - y) mod M, for x and y already reduced into [0, M).
inline int dec(int x, int y) {
	return (x - y < 0) ? x - y + M : x - y;
}

// (x * y) mod M, using a 64-bit intermediate to avoid overflow.
inline int mul(int x, int y) {
	return static_cast<long long>(x) * y % M;
}

// x^y mod M by binary exponentiation (square-and-multiply).
inline int pw(int x, int y) {
	int acc = 1;
	int base = x;
	while (y != 0) {
		if (y & 1) acc = mul(acc, base);
		base = mul(base, base);
		y >>= 1;
	}
	return acc;
}
void ntt(int *A,int opt){
	for (int i=0;i<lim;i++)
		if (i<r[i]) swap(A[i],A[r[i]]);
	for (int mid=1,pp=M-1>>1;mid<lim;mid<<=1,pp>>=1){
		int wn=pw(opt==1?g:gi,pp);
		for (int R=mid<<1,j=0;j<lim;j+=R){
			int w=1;
			for (int k=0;k<mid;k++,w=mul(w,wn)){
				int x=A[j|k],y=mul(w,A[j|k|mid]);
				A[j|k]=inc(x,y),A[j|k|mid]=dec(x,y);
			}
		}
	}
}
// Reads a digit count n and two n-digit decimal integers, then prints their
// product computed via NTT-based polynomial multiplication modulo M. Each
// convolution coefficient is at most 81 * 60001 < M, so it is recovered exactly.
int main(){
	scanf("%d%s",&n,s);
	n--;   // n is now the highest coefficient index
	// Store digits little-endian: a[k] is the coefficient of 10^k.
	for (i=0;i<=n;i++) a[n-i]=s[i]^48;
	scanf("%s",s);
	for (i=0;i<=n;i++) b[n-i]=s[i]^48;
	// Smallest power of two strictly greater than the product's top index 2n.
	for (;lim<=n<<1;lim<<=1,l++);
	// Bit-reversal permutation. For single-digit operands lim == 1 and l == 0;
	// shifting by l-1 == -1 would be undefined behaviour, so the high bit is
	// only contributed when l > 0.
	for (i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|(l>0?((i&1)<<(l-1)):0);
	ntt(a,1),ntt(b,1);
	// Pointwise product in the transformed domain = convolution of digits.
	for (i=0;i<lim;i++) a[i]=mul(a[i],b[i]);
	ntt(a,-1);
	inv=pw(lim,M-2);   // lim^{-1} mod M via Fermat's little theorem (M is prime)
	n<<=1;
	for (i=0;i<=n;i++) c[i]=mul(a[i],inv);
	// Propagate carries; the final carry lands in c[n+1].
	for (i=0;i<=n;i++) c[i+1]+=c[i]/10,c[i]%=10;
	n++;
	// Strip leading zeros but always keep c[0], so a zero product prints "0"
	// instead of scanning below the start of c[].
	while (n>0 && !c[n]) n--;
	for (i=n;i>=0;i--) printf("%d",c[i]);
}

猜你喜欢

转载自blog.csdn.net/xumingyang0/article/details/88420656