【洛谷P4245】任意模数NTT

任意模数NTT

首先我们取三个模数,使得它们的乘积大于 n P 2 7 2 2 6 + 1 998244353 479 2 21 + 1 这三个数就挺合适的,它们互质且原根都是3。
然后对于结果的每一位,我们就得到了中国剩余定理形式的式子:

a n s a 1 ( mod m 1 )

a n s a 2 ( mod m 2 )

a n s a 3 ( mod m 3 )

当然, m 1 m 2 m 3 敲大的,你肯定不能暴力合并,所以可以先把前两项合并了,得到 a n s A ( mod M ) ,然后就有 M x + A a 3 ( mod m 3 ) ,然后求出 x = M 1 ( a 3 A ) 之后,得到 a n s = M x + A ,此时在计算的时候直接模题目给定的模数。

代码

#include<bits/stdc++.h>
using namespace std;
#define RI register int
int read() {
    int q=0;char ch=' ';
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
    return q;
}
typedef long long LL;
const int N=262150,mm[3]={7*(1<<26)+1,998244353,479*(1<<21)+1},G=3;
int n,m,kn=1,len,mod;
int a[N],b[N],k1[N],k2[N],ans[3][N],rev[N];
int ksm(int x,int y,int p) {
    int re=1;
    for(;y;y>>=1,x=1LL*x*x%p) if(y&1) re=1LL*re*x%p;
    return re;
}
LL ksc(LL x,LL y,LL p) {return (x*y-(LL)((long double)x/p*y+1e-8)*p+p)%p;}
void NTT(int *a,int n,int p,int x) {
    for(RI i=0;i<n;++i) if(rev[i]>i) swap(a[i],a[rev[i]]);
    for(RI i=1;i<n;i<<=1) {
        int gn=ksm(G,(p-1)/(i<<1),p);
        for(RI j=0;j<n;j+=i<<1) {
            int g=1,t1,t2;
            for(RI k=0;k<i;++k,g=1LL*g*gn%p) {
                t1=a[j+k]%p,t2=1LL*g*a[j+i+k]%p;
                a[j+k]=(t1+t2)%p,a[j+i+k]=(t1-t2+p)%p;
            }
        }
    }
    if(x==1) return;
    int inv=ksm(n,p-2,p);reverse(a+1,a+n);
    for(RI i=0;i<n;++i) a[i]=1LL*a[i]*inv%p;
}
void work(int o) {
    for(RI i=0;i<kn;++i) k1[i]=a[i],k2[i]=b[i];
    NTT(k1,kn,mm[o],1),NTT(k2,kn,mm[o],1);
    for(RI i=0;i<kn;++i) ans[o][i]=1LL*k1[i]*k2[i]%mm[o];
    NTT(ans[o],kn,mm[o],-1);
}
int main()
{
    n=read(),m=read(),mod=read();
    for(RI i=0;i<=n;++i) a[i]=read();
    for(RI i=0;i<=m;++i) b[i]=read();
    while(kn<=n+m) kn<<=1,++len;
    for(RI i=1;i<kn;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
    for(RI i=0;i<3;++i) work(i);
    LL M=1LL*mm[0]*mm[1];
    LL kl1=ksc(mm[1],ksm(mm[1]%mm[0],mm[0]-2,mm[0]),M);
    LL kl2=ksc(mm[0],ksm(mm[0]%mm[1],mm[1]-2,mm[1]),M);
    for(RI i=0;i<=n+m;++i) {
        int t0=ksm(ans[0][i],mm[1]-2,mm[1]),t1=ksm(ans[1][i],mm[0]-2,mm[0]);
        LL A=(ksc(kl1,ans[0][i],M)+ksc(kl2,ans[1][i],M))%M;
        LL k=((ans[2][i]-A)%mm[2]+mm[2])%mm[2]*ksm(M%mm[2],mm[2]-2,mm[2])%mm[2];
        printf("%lld ",((M%mod)*(k%mod)%mod+A%mod)%mod);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/litble/article/details/81568686