[NTT] Luogu P4245 任意模数NTT

 

 

 题解

  • 用三模数NTT做,有点小细节,其他都是模板了

 

代码

 1 #include <cstdio>
 2 #include <iostream>
 3 #define ll long long
 4 using namespace std;
 5 int const N=(1<<19);
 6 int n,m,lim,l,rev[N],a[5][N],b[5][N],p[5]={0,469762049,998244353,1004535809};
 7 ll mo;
 8 ll pd(ll x,int mo) { while (x>=mo) x-=mo; while (x<0) x+=mo; return x; }
 9 ll mul(ll a,ll b,int mo) 
10 { 
11     ll r=0; a%=mo,b%=mo;
12     if (a<0) a+=mo; if (b<0) b+=mo;
13     for (;b;b>>=1ll,a=(a+a)%mo) if (b&1) r=(r+a)%mo; 
14     return r; 
15 }
16 ll ksm(ll a,ll b,int mo) { ll r=1; for (;b;b>>=1,a=mul(a,a,mo)) if (b&1) r=mul(r,a,mo); return r; }
17 void ntt(int *a,int f,int p)
18 {
19     for (int i=0;i<lim;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
20     for (int mid=1;mid<lim;mid<<=1)
21     {
22         int q=ksm(3,(p-1)/(mid<<1),p);
23         if (f==-1) q=ksm(q,p-2,p);
24         for (int j=0,len=(mid<<1);j<lim;j+=len)
25         {
26             int w=1;
27             for (int k=0,x,y;k<mid;k++,w=(ll)w*q%p) x=a[j+k],y=(ll)w*a[j+mid+k]%p,a[j+k]=pd(x+y,p),a[j+mid+k]=pd(x-y,p);
28         }
29     }
30     if (f==1) return;
31     int inv=ksm(lim,p-2,p); 
32     for (int i=0;i<lim;i++) a[i]=(ll)a[i]*inv%p;
33 }
34 ll uni(ll r1,ll r2,ll m1,ll m2,int f,int v)
35 {
36     ll k=mul(r2-r1,v,m2);
37     if (!f) return (r1+k*m1)%(m1*m2);
38     return pd((r1+mul(k,m1,mo))%mo,mo);
39 }
40 int main()
41 {
42     freopen("data.in","r",stdin),scanf("%d%d%lld",&n,&m,&mo);
43     for (int i=0;i<=n;i++) scanf("%d",&a[1][i]),a[2][i]=a[3][i]=a[1][i];
44     for (int i=0;i<=m;i++) scanf("%d",&b[1][i]),b[2][i]=b[3][i]=b[1][i];
45     lim=1; while (lim<=n+m+2) lim*=2,l++;
46     for (int i=0;i<lim;i++) rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1)));
47     for (int i=1;i<=3;i++)
48     {
49         ntt(a[i],1,p[i]),ntt(b[i],1,p[i]);
50         for (int j=0;j<lim;j++) a[i][j]=(ll)a[i][j]*b[i][j]%p[i];
51         ntt(a[i],-1,p[i]);
52     }
53     int inv1=ksm(p[1],p[2]-2,p[2]),inv2=ksm((ll)p[1]*p[2],p[3]-2,p[3]);
54     for (int i=0;i<=n+m;i++)
55     {
56         ll ans=uni(a[1][i],a[2][i],p[1],p[2],0,inv1);
57         ans=uni(ans,a[3][i],(ll)p[1]*p[2],p[3],1,inv2),printf("%lld ",ans);
58     }
59 }

 

猜你喜欢

转载自www.cnblogs.com/Comfortable/p/11363781.html
0条评论
添加一条新回复