luogu P4512 多项式除法 (模板题、FFT、多项式求逆)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/suncongbo/article/details/84853342

题目链接: https://www.luogu.org/problemnew/show/P4512

没想到这算法这么蠢。。一点都不难啊。。我连这都推不出来我是不是没救了

这个多项式满足 A ( x ) = B ( x ) Q ( x ) + R ( x ) A(x)=B(x)Q(x)+R(x) , 如果已知 R ( x ) R(x) 0 0 , 那显然很好处理,求个逆就行了。
那如果有余数呢?很简单,如果我们把这个多项式的系数翻转(reverse),那么 R ( x ) R(x) 就从低次项变成了高次项,低次项就不再受 R ( x ) R(x) 的而影响了。
这是我们的基本思路。下面我们来形式化这个过程:
A ( x ) = B ( x ) Q ( x ) + R ( x ) A(x)=B(x)Q(x)+R(x)
对于 n n 次多项式 F ( x ) F(x) F R ( x ) = x n F ( 1 x ) F_R(x)=x^nF(\frac{1}{x}) (这就是前面所说的reverse操作)
则有 x n A ( 1 x ) = x m B ( 1 x ) x n m Q ( 1 x ) + x m 1 R ( 1 x ) x n m + 1 x^nA(\frac{1}{x})=x^mB(\frac{1}{x})x^{n-m}Q(\frac{1}{x})+x^{m-1}R(\frac{1}{x})x^{n-m+1}
A R ( x ) = B R ( x ) Q R ( x ) + x n m + 1 R R ( x ) A_R(x)=B_R(x)Q_R(x)+x^{n-m+1}R_R(x)
A R ( x ) B R ( x ) Q R ( x ) ( m o d    x n m + 1 ) A_R(x)\equiv B_R(x)Q_R(x) (\mod x^{n-m+1})
于是我们求 B R ( x ) B_R(x) m o d    x n m + 1 \mod x^{n-m+1} 意义下的逆,然后乘以 A R ( x ) A_R(x) 即可求出 Q R ( x ) Q_R(x) , 从而得到 Q ( x ) Q(x) .
然后用 R ( x ) = A ( x ) B ( x ) Q ( x ) R(x)=A(x)-B(x)Q(x) 即可求出 R ( x ) R(x) .
(虽然算法简单但是要注意的地方还挺多……容易错。)
时间复杂度 O ( n log n ) O(n\log n) , 我写的进行了 24 24 倍常数的ntt.
空间复杂度 O ( n ) O(n) , 我的实现好像需要开 8 8 倍。

代码实现

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define llong long long
#define modinc(x) {if(x>=P) x-=P;}
using namespace std;

const int N = 1<<19;
const int LGN = 19;
const llong G = 3ll;
const int P = 998244353;
llong tmp1[N+3],tmp2[N+3],tmp3[N+3],tmp4[N+3];
llong tmp5[N+3],tmp6[N+3],tmp7[N+3],tmp8[N+3],tmp9[N+3];
llong a[N+3],b[N+3],q[N+3],r[N+3];
int id[N+3];
int n,m;

llong quickpow(llong x,llong y)
{
 llong cur = x,ret = 1ll;
 for(int i=0; y; i++)
 {
  if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
  cur = cur*cur%P;
 }
 return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2)%P;}

void initid(int dgr)
{
 int len = 0; for(int i=0; i<=LGN; i++) if((1<<i)==dgr) {len = i; break;}
 id[0] = 0;
 for(int i=1; i<dgr; i++) id[i] = (id[i>>1]>>1)|(i&1)<<(len-1);
}

int getdgr(int x)
{
 int ret = 1;
 while(ret<=x) ret<<=1;
 return ret;
}

void ntt(int dgr,int coe,llong poly[],llong ret[])
{
 initid(dgr);
 for(int i=0; i<dgr; i++) ret[i] = poly[i];
 for(int i=0; i<dgr; i++) if(i<id[i]) swap(ret[i],ret[id[i]]);
 for(int i=1; i<=(dgr>>1); i<<=1)
 {
  llong tmp = quickpow(G,(P-1)/(i<<1));
  if(coe==-1) tmp = mulinv(tmp);
  for(int j=0; j<dgr; j+=(i<<1))
  {
   llong expn = 1ll;
   for(int k=0; k<i; k++)
   {
    llong x = ret[j+k],y = ret[j+i+k]*expn%P;
    ret[j+k] = x+y; modinc(ret[j+k]);
    ret[j+i+k] = x-y+P; modinc(ret[j+i+k]);
    expn = (expn*tmp)%P;
   }
  }
 }
 if(coe==-1)
 {
  llong tmp = mulinv(dgr);
  for(int j=0; j<dgr; j++) ret[j] = ret[j]*tmp%P;
 }
}

void polyinv(int dgr,llong poly[],llong ret[])
{
 for(int i=0; i<dgr; i++) ret[i] = 0ll;
 ret[0] = mulinv(poly[0]);
 for(int i=1; i<=(dgr>>1); i<<=1)
 {
  for(int j=0; j<(i<<2); j++) tmp1[j] = j<i ? ret[j] : 0ll;
  for(int j=0; j<(i<<2); j++) tmp2[j] = j<(i<<1) ? poly[j] : 0ll;
  ntt((i<<2),1,tmp1,tmp3); ntt((i<<2),1,tmp2,tmp4);
  for(int j=0; j<(i<<2); j++) tmp3[j] = tmp3[j]*tmp3[j]%P*tmp4[j]%P;
  ntt((i<<2),-1,tmp3,tmp4);
  for(int j=0; j<(i<<1); j++) ret[j] = (tmp1[j]+tmp1[j]-tmp4[j]+P)%P; 
 }
 for(int j=dgr; j<(dgr<<1); j++) ret[j] = 0ll;
}

void polyrev(int dgr,llong poly[],llong ret[])
{
 for(int i=0; i<dgr; i++) ret[i] = poly[dgr-1-i];
}

void polydiv(int dgr1,int dgr2,llong poly1[],llong poly2[],llong ret1[],llong ret2[])
{
 int _dgr1 = getdgr(dgr1),_dgr2 = getdgr(dgr2);
 polyrev(dgr2,poly2,tmp5); polyrev(dgr1,poly1,tmp9);
 polyinv(_dgr1,tmp5,tmp6);
 for(int i=dgr1-dgr2+1; i<(_dgr1<<1); i++) tmp6[i] = 0ll;
 ntt(_dgr1<<1,1,tmp9,tmp7); ntt(_dgr1<<1,1,tmp6,tmp8);
 for(int i=0; i<(_dgr1<<1); i++) tmp7[i] = tmp7[i]*tmp8[i]%P;
 ntt(_dgr1<<1,-1,tmp7,tmp8);
 for(int i=dgr1-dgr2+1; i<(_dgr1<<1); i++) tmp8[i] = 0ll;
 polyrev(dgr1-dgr2+1,tmp8,ret1);
 ntt(_dgr1<<1,1,poly2,tmp7); ntt(_dgr1<<1,1,ret1,tmp8);
 for(int i=0; i<(_dgr1<<1); i++) tmp7[i] = tmp7[i]*tmp8[i]%P;
 ntt(_dgr1<<1,-1,tmp7,ret2);
 for(int i=dgr2; i<(_dgr1<<1); i++) ret2[i] = 0ll;
 for(int i=0; i<dgr2-1; i++) ret2[i] = (poly1[i]-ret2[i]+P)%P;
}

int main()
{
 scanf("%d%d",&n,&m); n++; m++;
 for(int i=0; i<n; i++) scanf("%lld",&a[i]);
 for(int i=0; i<m; i++) scanf("%lld",&b[i]);
 int dgr1 = getdgr(n),dgr2 = getdgr(m);
 polydiv(n,m,a,b,q,r);
 for(int i=0; i<=n-m; i++) printf("%lld ",q[i]); puts("");
 for(int i=0; i<m-1; i++) printf("%lld ",r[i]);
 return 0;
}

猜你喜欢

转载自blog.csdn.net/suncongbo/article/details/84853342