洛谷 4245 【模板】任意模数NTT

题目:https://www.luogu.org/problemnew/show/P4245

大概是用3个模数分别做一遍,用中国剩余定理合并。

前两个合并起来变成一个 long long 的模数,再要和第三个合并的话就爆 long long ,所以可以用一种让两个模数的乘积不出现的方法:https://blog.csdn.net/qq_35950004/article/details/79477797

 x*m1+a1 = -y*m2 + a2  <==>  x*m1+y*m2 = a2-a1  <==>  x*m1 = a2-a1 (mod m2)  <==> x=(a2-a1)*m1^{-1} (mod m2)

然后根据该博客里的证明,在mod m2意义下算出来的 x 就是真的 x 。这样的话答案就是 x*m1+a1 ,可以在快速乘的过程中对题目中给的模数取模,就不会爆 long long 啦。

注意输入的 a[ ] 和 b[ ] 不能 ntt( ,0, ) 之后再 ntt( ,1, ) 回来,因为值已经模了刚才那个模数了;所以要多开一些数组。

注意输入进 mul 里的 a 和 b 应该是正的,不然没法 b>>=1 之类的。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1e5+5;
int m[3]={998244353,1004535809,469762049};
int n0,n1,mod,len,r[N<<2],a[3][N<<2],b[3][N<<2],c[3][N<<2];
ll M=(ll)m[0]*m[1],d[N<<1];
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
void upd(ll &x,ll md){x>=md?x-=md:0;}
void upd(int &x,ll md){x>=md?x-=md:0;}
ll mul(ll a,ll b,ll md)
{
  a%=md; b%=md;//
  ll ret=0;while(b){if(b&1ll)ret+=a,upd(ret,md);a+=a;upd(a,md);b>>=1ll;}return ret;
}
ll pw(ll x,ll k,ll md)
{ll ret=1;while(k){if(k&1ll)ret=mul(ret,x,md);x=mul(x,x,md);k>>=1ll;}return ret;}
void ntt(int *a,bool fx,int md)
{
  for(int i=0;i<len;i++)
    if(i<r[i])swap(a[i],a[r[i]]);
  for(int R=2;R<=len;R<<=1)
    {
      int m=R>>1;
      int Wn=pw(3,(md-1)/R,md);
      fx?Wn=pw(Wn,md-2,md):0;
      for(int i=0;i<len;i+=R)
    for(int j=0,w=1;j<m;j++,w=(ll)w*Wn%md)
      {
        int tmp=(ll)w*a[i+m+j]%md;
        a[i+m+j]=a[i+j]+md-tmp; upd(a[i+m+j],md);
        a[i+j]=a[i+j]+tmp; upd(a[i+j],md);
      }
    }
  if(!fx)return; int inv=pw(len,md-2,md);
  for(int i=0;i<len;i++) a[i]=(ll)a[i]*inv%md;
}
int main()
{
  n0=rdn()+1; n1=rdn()+1; mod=rdn();
  for(int i=0;i<n0;i++)a[0][i]=a[1][i]=a[2][i]=rdn();
  for(int i=0;i<n1;i++)b[0][i]=b[1][i]=b[2][i]=rdn();
  for(len=1;len<=n0+n1;len<<=1);
  for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0);
  for(int i=0;i<3;i++)//don't ntt(a,1,m[i]) for it can't return(already mod)
    {
      ntt(a[i],0,m[i]); ntt(b[i],0,m[i]);
      for(int j=0;j<len;j++)c[i][j]=(ll)a[i][j]*b[i][j]%m[i];
      ntt(c[i],1,m[i]);
    }

  ll inv=pw(m[0],m[1]-2,m[1]),t;
  for(int i=0,lm=n0+n1-1;i<lm;i++)
    {
      t=mul((c[1][i]-c[0][i])%m[1]+m[1],inv,m[1]);
      d[i]=(mul(t,m[0],M)+c[0][i])%M;
    }
  inv=pw(M,m[2]-2,m[2]);
  for(int i=0,lm=n0+n1-1;i<lm;i++)
    {
      t=mul((c[2][i]-d[i])%m[2]+m[2],inv,m[2]);
      d[i]=(mul(t,M,mod)+d[i])%mod;
      printf("%lld ",d[i]);
    }
  puts(""); return 0;
}

猜你喜欢

转载自www.cnblogs.com/Narh/p/10035325.html