NTT任意模数模板(+O(1)快速乘)

NTT任意模数的方法其实有点取巧。

两个数列每个有n个数,每个数的大小最多是10^9。

如果没有模数,那么卷积过后每个位置的答案一定小于10^9*10^9*n,差不多是10^24左右

那么就有一个神奇的做法,选3个乘积大于10^24的NTT模数,分别做一次,得到每个位上模意义下的答案,

然后用中国剩余定理得到模上三个质数乘积的答案。

因为答案显然小于三个质数乘积,那么模上三个质数乘积的答案就是这个数应该的值。

不过这个值可能会超long long(及时不超,对于乘积大于long long的三个质数做中国剩余定理也不是一件小事)

考虑先将两个模数意义下的答案合并,

现在我们还剩两个模数,一个为long long,一个为int

不能中国剩余定理硬上了。

设模数为P1(longlong) ,P2(int), 余数为a1,a2

设答案ANS=P1*K+a1=P2*Q+a2

那么K*P1=P2*Q+(a2-a1)

K*P1 % P2=a2-a1

a1-a2为常数

用同余方程的解法即可解出K模P2(int)意义下的值

又有ANS<P1*P2(之前已证)

so K*P1+a1<P1*P2

显然K<P2

所以原本答案K的值只能为模P2意义下的值

所以我们就求出K了,然后可以不用高精度就算出ANS%MOD(MOD为任意模数)

但是,

回顾整个过程,附加条件非常多。。。。。。

首先每个数<=10^9(再大或许可以通过增加模数的方法解决,但是CRT时可就不能回避高精度取模了,常数捉急)

然后K的值必须为非负.(如果为负数那么就有两个可能的答案了,这是你用第一条性质怎样都无法回避的)

其次你需要解决两个long long相乘mod long long

用二分乘法会T(常数啊),可以用接近作弊的O(1)long long乘法取模:

//O(1)快速乘
LL mul(LL a,LL b,LL P)
{
	a=(a%P+P)%P,b=(b%P+P)%P;
	return ((a*b-(LL)((long double)a/P*b+1e-6)*P)%P+P)%P;
}

主要原理是mod后的答案用公式 A % B=A-floor(A/B)*B来算。

注意其中A和floor(A/B)*B都是可能爆long long的。

但是因为减法,所以无论是两个都不溢出还是两个都溢出亦或是一个溢出另一个不溢出,都没有关系。。。。。。。

(好像2009集训队论文中关于底层优化的那篇上有)

模板:

#include<cstdio>
#include<cstring>
#include<cctype>
//#include<ctime>
#include<algorithm>
#define maxn 300005
#define LL long long
#define RealMod 1000000007
using namespace std;
 
int n,m;
int P[3]={998244353,1004535809,469762049},G[3]={3,3,3},invG[3],wn[3][2][maxn];
 
inline int pow(int base,int k,int P)
{
	int ret=1;
	for(;k;k>>=1,base=1ll*base*base%P) if(k&1) ret=1ll*ret*base%P;
	return ret;
}
 
inline void Prework(int id)
{
	invG[id]=pow(G[id],P[id]-2,P[id]);
	for(int i=1;i<24;i++) wn[id][1][i]=pow(G[id],(P[id]-1)/(1<<i),P[id]),wn[id][0][i]=pow(invG[id],(P[id]-1)/(1<<i),P[id]);
}
 
inline void NTT(int *A,int n,int typ,int id)
{
	for(int i=0,j=0,k;i<n;i++)
	{
		if(i<j) swap(A[i],A[j]);
		for(k=n>>1;k;k>>=1) if((j^=k)>=k) break;
	}
 
	for(int i=1,j,k,len,w,x,y;1<<i<=n;i++)	
	{
		len=1<<(i-1);
		for(j=0,w=1;j<n;j+=1<<i,w=1)
			for(k=0;k<len;k++,w=1ll*w*wn[id][typ][i]%P[id])
			{
				x=A[j+k],y=1ll*A[j+k+len]*w%P[id];
				A[j+k]=(x+y)%P[id];
				A[j+k+len]=(x-y+P[id])%P[id];
			}
	}
	
	if(typ==0)
		for(int i=0,inv=pow(n,P[id]-2,P[id]);i<n;i++)
			A[i]=1ll*A[i]*inv%P[id];
}
 
void mul(int *ret,int *A,int lena,int *B,int lenb,int id)
{
	static int seq1[maxn],seq2[maxn];
	int n=1;for(;n<=lena+lenb;n<<=1);
	for(int i=0;i<n;i++) 
	{
		if(i<=lena) seq1[i]=A[i]; else seq1[i]=0;
		if(i<=lenb) seq2[i]=B[i]; else seq2[i]=0;
	}
	NTT(seq1,n,1,id);NTT(seq2,n,1,id);
	for(int i=0;i<n;i++) ret[i]=1ll*seq1[i]*seq2[i]%P[id];
	NTT(ret,n,0,id);
}
 
//O(1)快速乘
LL mul(LL a,LL b,LL P)
{
	a=(a%P+P)%P,b=(b%P+P)%P;
	return ((a*b-(LL)((long double)a/P*b+1e-6)*P)%P+P)%P;
}
/*
long long mul (long long a, long long b, long long mod) {
	a%=mod,b%=mod;
    if (b == 0)
        return 0;
    long long ans = mul (a, b>>1, mod);
    ans = ans*2%mod;
    if (b&1) ans += a, ans %= mod;
    return (ans+mod)%mod;
}*/
 
 
int a[maxn],b[maxn],c[3][maxn];
LL Mod=1ll*P[0]*P[1];
LL inv1=pow(P[0],P[1]-2,P[1]),inv2=pow(P[1],P[0]-2,P[0]),inv=pow(Mod%P[2],P[2]-2,P[2]);
inline void solve(int i)
{
	LL C=(mul(1ll*c[0][i]*P[1]%Mod,inv2,Mod)+mul(1ll*c[1][i]*P[0]%Mod,inv1,Mod))%Mod;
	LL K=1ll*((1ll*c[2][i]-(C%P[2]))%P[2])*(inv%P[2])%P[2];
	c[0][i]=(((K%RealMod+RealMod)*(Mod%RealMod)%RealMod+C)%RealMod);
}
 
int main()
{
	
	//freopen("1.in","r",stdin);
	
	//int t1=clock();
	
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;i++) scanf("%d",&a[i]);
	for(int j=0;j<=m;j++) scanf("%d",&b[j]);
	for(int i=0;i<3;i++) Prework(i),mul(c[i],a,n,b,m,i);
	for(int i=0;i<=n+m;i++) 
		solve(i); 
	for(int i=0;i<n+m;i++)
	{
		int tmp=(c[0][i]+RealMod)%RealMod;
		printf("%d ",tmp);
	}
	printf("%d\n",(c[0][n+m]+RealMod)%RealMod);
	
	//printf("%d\n",clock()-t1);
}

原文:https://blog.csdn.net/qq_35950004/article/details/79477797

猜你喜欢

转载自blog.csdn.net/tianwei0822/article/details/82346338