2018.11.14 uoj#34. 多项式乘法(ntt)

版权声明:随意转载哦......但还是请注明出处吧: https://blog.csdn.net/dreaming__ldx/article/details/84145535

传送门
今天学习 n t t ntt
其实递归方法和 f f t fft 是完全相同的。
只不过 f f t fft 的单位根用的是复数中的东西,而 n t t ntt 用的是数论里面有相同性质的原根。
代码:

#include<bits/stdc++.h>
using namespace std;
inline int read(){
	int ans=0;
	 char ch=getchar();
	 while(!isdigit(ch))ch=getchar();
	 while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
	 return ans;
}
typedef long long ll;
const int mod=998244353,N=4e5+5;
int pos[N],n,m,a[N],b[N],lim=1,tim=0;
inline int ksm(int a,int p){int ret=1;for(;p;p>>=1,a=(ll)a*a%mod)if(p&1)ret=(ll)ret*a%mod;return ret;}
inline void ntt(int a[],int type){
	for(int i=0;i<lim;++i)if(i<pos[i])swap(a[i],a[pos[i]]);
	int mult=(mod-1)>>1,typ=type==1?3:(mod+1)/3;
	for(int wn,mid=1;mid<lim;mid<<=1,mult>>=1){
		wn=ksm(typ,mult);
		for(int j=0,len=mid<<1;j<lim;j+=len){
			int w=1;
			for(int k=0;k<mid;++k,w=(ll)w*wn%mod){
				int a0=a[j+k],a1=(ll)a[j+k+mid]*w%mod;
				a[j+k]=(a0+a1)%mod,a[j+k+mid]=(a0-a1+mod)%mod;
			}
		}
	}
}
int main(){
	n=read(),m=read();
	for(int i=0;i<=n;++i)a[i]=read();
	for(int i=0;i<=m;++i)b[i]=read();
	while(lim<=n+m)lim<<=1,++tim;
	for(int i=0;i<lim;++i)pos[i]=(pos[i>>1]>>1)|((i&1)<<(tim-1));
	ntt(a,1),ntt(b,1);
	for(int i=0;i<lim;++i)a[i]=(ll)a[i]*b[i]%mod;
	ntt(a,-1);
	int inv=ksm(lim,mod-2);
	for(int i=0;i<=n+m;++i)printf("%d ",(ll)a[i]*inv%mod);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/dreaming__ldx/article/details/84145535