【笔记】多项式求逆

Problem

对于一个多项式 a ( x ) ,求其逆元 b ( x ) ,即 a ( x ) b ( x ) 1 ( mod x n )

Solution

对于单个元素的逆元我们是会求的,比如说一个数 t 的逆元在膜质数意义下为 t p 2

但现在要求求一个多项式的逆元,联想到在模数为 x 时可以快速求得其逆元为 a ( 0 ) 1 ,可以考虑从这里开始递推

对于题目可设
a b 1 ( mod x 2 p )
a c 1 ( mod x p )

即已知 a , c b

a b 1 ( mod x p )
a c 1 ( mod x p )

b c 0 ( mod x p )

b 2 2 b c + c 2 0 ( mod x 2 p )
同乘 a
a b 2 2 a b c + a c 2 0 ( mod x 2 p )
考虑到 a b 1 ( mod x 2 p )
b 2 c + a c 2 0 ( mod x 2 p )
b 2 c a c 2 ( mod x 2 p )

即得到了一个用 a , c 表示 b 的递推式

借助NTT的膜数乘法,时间复杂度为 O ( n log 2 n )

不过一开始觉得时间复杂度应该是 O ( n log 2 2 n )

后来发现复杂度应该是 O ( n log n + n 2 log n 2 + ) = O ( n log 2 n )

发现自己还是数学思维太弱了

Code

#include<algorithm>
#include<cstdio>
#include<cctype>
using namespace std;
#define rg register

template <typename _Tp> inline _Tp read(_Tp&x){
    rg char c11=getchar(),ob=0;x=0;
    while(c11^'-'&&!isdigit(c11))c11=getchar();if(c11=='-')c11=getchar(),ob=1;
    while(isdigit(c11))x=x*10+c11-'0',c11=getchar();if(ob)x=-x;return x;
}

const int N=2001000,G=3,p=998244353;
int a[N],b[N],c[N],rev[N];

inline int qpow(int A,int B){
    int res(1);
    while(B){
        if(B&1)res=1ll*res*A%p;
        A=1ll*A*A%p;B>>=1;
    }return res;
}

void ntt(int*a,int n,int f){
    for(rg int i=0;i<n;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(rg int i=1;i<n;i<<=1){
        int wn=qpow(G,(p-1)/(i<<1));
        for(rg int j=0;j<n;j+=(i<<1)){
            int w(1);
            for(rg int k=0;k<i;++k,w=1ll*w*wn%p){
                int x=a[j+k],y=1ll*w*a[j+k+i]%p;
                a[j+k]=(x+y)%p,a[j+k+i]=(x-y+p)%p;
            }
        }
    }
    if(f==1)return ;
    int tmp=qpow(n,p-2);reverse(a+1,a+n);
    for(rg int i=0;i<n;++i)a[i]=1ll*a[i]*tmp%p;
}

void work(int d,int*a,int*b){
    if(d==1){b[0]=qpow(a[0],p-2);return ;}
    work((d+1)>>1,a,b);
    int l(0),nn(1);
    while(nn<(d<<1))nn<<=1,++l;
    for(rg int i=1;i<nn;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
    for(rg int i=0;i<d;++i)c[i]=a[i];
    for(rg int i=d;i<nn;++i)c[i]=0;
    ntt(c,nn,1);ntt(b,nn,1);
    for(rg int i=0;i<nn;++i)
        b[i]=1ll*(2-1ll*b[i]*c[i]%p+p)%p*b[i]%p;
    ntt(b,nn,-1);
    for(rg int i=d;i<nn;++i)b[i]=0;
    return ;
}

int main(){
    int n;read(n);
    for(rg int i=0;i<n;++i)read(a[i]);
    work(n,a,b);
    for(rg int i=0;i<n;++i)printf("%d ",b[i]);
    putchar('\n');return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_40515553/article/details/80503328