多项式全操作小讲 COGS2189 帕秋莉的超级多项式

版权声明:这篇文章的作者是个蒟蒻,没有转载价值,如果要转说一下好了 https://blog.csdn.net/litble/article/details/81749788

乘法

戳我

求逆

戳我

开根

戳我

求导

将多项式 A 求导,求导结果是 B ,则 B i = ( i + 1 ) A i + 1

积分

将多项式 A 积分,积分结果是 B ,则 B i = A i 1 i

求ln

一般要求求ln的题,常数项都为0.
考虑求 G ( F ( x ) ) ,其中 G ( x ) = l n x
我们知道若有 y = f ( g ( x ) ) , u = g ( x ) ,则 d y d x = d y d u d u d x = f ( u ) g ( x )
也知道对 l n x 求导的结果是 1 x
所以对原式求导,得 F ( x ) F ( x ) ,然后积分即可。

求exp

求满足 e F ( x ) = G ( x ) F ( x ) 。则有 l n G ( x ) F ( x ) = 0 。考虑牛顿迭代,设 H ( x ) 为在 mod x n 2 下的解,则有 F ( x ) = ( 1 l n H ( x ) + G ( x ) ) H ( x )
会求ln和exp后,显然所有指数函数和对数函数都能搞了。

求幂

快速幂?复杂度不够优秀!
忘了说了,这种不断将 n 的规模缩小至 1 2 ,在每层上做 n l o g n 的计算,复杂度都是 O ( n l o g n ) 的,也就是上面这些操作都是 O ( n l o g n ) 的。但是如果你用快速幂求幂的话,就是 O ( n l o g 2 n ) 的了,这样不太好。
考虑 F k ( x ) = e l n F k ( x ) = e k l n F ( x ) ,变成了 O ( n l o g n ) 的。

放上代码后,再讲一点点。

#include<bits/stdc++.h>
using namespace std;
#define RI register int
int read() {
    int q=0;char ch=' ';
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
    return q;
}
const int N=262150,mod=998244353,inv2=499122177,G=3;
int n,K,kn=1;
int A[N],B[N],len[N],rev[N],inv[N];
int k1[N],k2[N],k3[N],k4[N],k5[N],k6[N],k7[N],k8[N];
int ksm(int x,int y) {
    int re=1;
    for(;y;y>>=1,x=1LL*x*x%mod) if(y&1) re=1LL*re*x%mod;
    return re;
}
int qm(int x) {return x>=mod?x-mod:x;}
void QAQ(int n) {for(RI i=0;i<n;++i) A[i]=B[i],B[i]=0;}
void NTT(int *a,int n,int x) {
    for(RI i=0;i<n;++i) if(rev[i]>i) swap(a[i],a[rev[i]]);
    for(RI i=1;i<n;i<<=1) {
        int gn=ksm(G,(mod-1)/(i<<1));
        for(RI j=0;j<n;j+=(i<<1)) {
            int t1,t2,g=1;
            for(RI k=0;k<i;++k,g=1LL*g*gn%mod) {
                t1=a[j+k],t2=1LL*g*a[j+i+k]%mod;
                a[j+k]=qm(t1+t2),a[j+i+k]=qm(t1+mod-t2);
            }
        }
    }
    if(x==1) return;
    reverse(a+1,a+n);
    for(RI i=0;i<n;++i) a[i]=1LL*a[i]*inv[n]%mod;
}
void getrev(int n)
    {for(RI i=0;i<n;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len[n]-1));}
void getinv(int *a,int *b,int n) {
    if(n==1) {b[0]=ksm(a[0],mod-2),b[1]=0;return;}
    getinv(a,b,n>>1);int kn=n<<1;
    for(RI i=0;i<n;++i) k3[i]=a[i],k3[i+n]=b[i+n]=0;
    getrev(kn),NTT(k3,kn,1),NTT(b,kn,1);
    for(RI i=0;i<kn;++i) b[i]=1LL*(2LL-1LL*b[i]*k3[i]%mod+mod)%mod*b[i]%mod;
    NTT(b,kn,-1);
    for(RI i=n;i<kn;++i) b[i]=0;
}
void getsqrt(int *a,int *b,int n) {
    if(n==1) {b[0]=sqrt(a[0]),b[1]=0;return;}
    getsqrt(a,b,n>>1);int kn=n<<1;
    getinv(b,k1,n);
    for(RI i=0;i<n;++i) k2[i]=a[i],k2[i+n]=k1[i+n]=b[i+n]=0;
    getrev(kn),NTT(k1,kn,1),NTT(k2,kn,1),NTT(b,kn,1);
    for(RI i=0;i<kn;++i) b[i]=1LL*qm(b[i]+1LL*k2[i]*k1[i]%mod)*inv2%mod;
    NTT(b,kn,-1);
    for(RI i=n;i<kn;++i) b[i]=0;
}
void getJF(int *a,int *b,int n)
    {for(RI i=1;i<n;++i) b[i]=1LL*a[i-1]*inv[i]%mod;b[0]=0;}
void getdao(int *a,int *b,int n)
    {for(RI i=1;i<n;++i) b[i-1]=1LL*a[i]*i%mod,b[n-1]=0;}
void getln(int *a,int *b,int n) {
    getdao(a,k6,n),getinv(a,k7,n);
    int kn=n<<1;
    for(RI i=n;i<kn;++i) k6[i]=k7[i]=0;
    getrev(kn),NTT(k6,kn,1),NTT(k7,kn,1);
    for(RI i=0;i<kn;++i) k6[i]=1LL*k6[i]*k7[i]%mod;
    NTT(k6,kn,-1),getJF(k6,b,n);
    for(RI i=n;i<kn;++i) b[i]=0;
}
void getexp(int *a,int *b,int n) {
    if(n==1) {b[0]=1,b[1]=0;return;}
    getexp(a,b,n>>1);int kn=n<<1;
    getln(b,k4,n);
    for(RI i=0;i<n;++i) k5[i]=qm(a[i]-k4[i]+mod),k5[i+n]=b[i+n]=0;
    k5[0]=qm(k5[0]+1);
    getrev(kn),NTT(b,kn,1),NTT(k5,kn,1);
    for(RI i=0;i<kn;++i) b[i]=1LL*b[i]*k5[i]%mod;
    NTT(b,kn,-1);
    for(RI i=n;i<kn;++i) b[i]=0;
}
void getksm(int *a,int *b,int n,int K) {
    getln(a,k8,n);
    for(RI i=0;i<n;++i) k8[i]=1LL*k8[i]*K%mod;
    getexp(k8,b,n);
}
int main()
{
    n=read(),K=read()%mod;
    for(RI i=0;i<n;++i) A[i]=read();
    while(kn<n) kn<<=1,len[kn]=len[kn>>1]+1;
    len[kn<<1]=len[kn]+1;
    inv[1]=1;for(RI i=2;i<=(kn<<1);++i) inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod;
    getsqrt(A,B,kn),QAQ(kn);
    getinv(A,B,kn),QAQ(kn);
    getJF(A,B,n),QAQ(kn);
    getexp(A,B,kn),QAQ(kn);
    getinv(A,B,kn),QAQ(kn),A[0]=qm(A[0]+1);
    getln(A,B,kn),QAQ(kn),A[0]=qm(A[0]+1);
    getksm(A,B,kn,K),QAQ(kn);
    getdao(A,B,n);
    for(RI i=0;i<n;++i) printf("%d ",B[i]);
    return 0;
}

除法

关键在于除数与被除数的度数不同,就不能直接多项式求逆来做,很烦。
F R ( x ) 表示将 F ( x ) 翻转,也就是设 F ( x ) n 次多项式,那么 F R ( x ) = x n F ( 1 x )
设要求 H ( x ) = G ( x ) F ( x ) ,其中 G ( x ) n 次的, F ( x ) m 次的。设余数为 Q ( x ) ,那么 G ( x ) = H ( x ) F ( x ) + Q ( x )
等式的 x 全部变成 1 x ,然后乘上 x n
那么: x n G ( 1 x ) = x m H ( 1 x ) x n m F ( 1 x ) + x n m + 1 x m 1 Q ( 1 x )
也就是: G R ( x ) = H R ( x ) F R ( x ) + x n m + 1 Q R ( x )
发现在模 x n m + 1 意义下, Q ( x ) 每一项都是0。
所以 G R ( x ) = H R ( x ) F R ( x ) ( mod x n m + 1 )
真完美,多项式求逆可以解决本题啦。
放出洛谷P4512的代码:

#include<bits/stdc++.h>
using namespace std;
#define RI register int
int read() {
    int q=0;char ch=' ';
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
    return q;
}
const int mod=998244353,N=262150,G=3;
int n,m;
int A[N],revA[N],B[N],revB[N],rev[N],len[N];
int revC[N],kl[N];
int ksm(int x,int y) {
    int re=1;
    for(;y;y>>=1,x=1LL*x*x%mod) if(y&1) re=1LL*re*x%mod;
    return re;
}
void NTT(int *a,int n,int x) {
    for(RI i=0;i<n;++i) if(rev[i]>i) swap(a[i],a[rev[i]]);
    for(RI i=1;i<n;i<<=1) {
        int gn=ksm(G,(mod-1)/(i<<1));
        for(RI j=0;j<n;j+=(i<<1)) {
            int t1,t2,g=1;
            for(RI k=0;k<i;++k,g=1LL*g*gn%mod) {
                t1=a[j+k],t2=1LL*g*a[j+i+k]%mod;
                a[j+k]=(t1+t2)%mod,a[j+i+k]=(t1-t2+mod)%mod;
            }
        }
    }
    if(x==1) return;
    reverse(a+1,a+n);int inv=ksm(n,mod-2);
    for(RI i=0;i<n;++i) a[i]=1LL*a[i]*inv%mod;
}
void getrev(int k)
    {for(RI i=0;i<k;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len[k]-1));}
void getinv(int *a,int *b,int n) {
    if(n==1) {b[0]=ksm(a[0],mod-2);return;}
    getinv(a,b,n>>1);
    int kn=n<<1;getrev(kn);
    for(RI i=0;i<n;++i) kl[i]=a[i],kl[i+n]=0;
    NTT(kl,kn,1),NTT(b,kn,1);
    for(RI i=0;i<kn;++i) b[i]=1LL*(2-1LL*b[i]*kl[i]%mod+mod)%mod*b[i]%mod;
    NTT(b,kn,-1);
    for(RI i=n;i<kn;++i) b[i]=0;
}
int main()
{
    n=read(),m=read();
    for(RI i=0;i<=n;++i) A[i]=revA[n-i]=read();
    for(RI i=0;i<=m;++i) B[i]=revB[m-i]=read();
    int kn=1;while(kn<n-m+1) kn<<=1,len[kn]=len[kn>>1]+1;
    len[kn<<1]=len[kn]+1;
    getinv(revB,revC,kn);
    kn<<=1;len[kn]=len[kn>>1]+1;
    for(RI i=n-m+1;i<kn;++i) revA[i]=revC[i]=0;
    getrev(kn),NTT(revC,kn,1),NTT(revA,kn,1);
    for(RI i=0;i<kn;++i) revA[i]=1LL*revA[i]*revC[i]%mod;
    NTT(revA,kn,-1);
    reverse(revA,revA+n-m+1);
    for(RI i=0;i<n-m+1;++i) printf("%d ",revA[i]);
    puts("");

    for(RI i=n-m+1;i<kn||i<=n;++i) revA[i]=0;
    while(kn<=n+1) kn<<=1,len[kn]=len[kn>>1]+1;
    getrev(kn);
    NTT(revA,kn,1),NTT(B,kn,1);
    for(RI i=0;i<kn;++i) revA[i]=1LL*revA[i]*B[i]%mod;
    NTT(revA,kn,-1);
    for(RI i=0;i<m;++i) printf("%d ",(A[i]-revA[i]+mod)%mod);
    puts("");
    return 0;
}

插值

插值是获得了 n + 1 个多项式的点值表达后将该 n 次多项式还原的操作。
方法是拉格朗日插值多项式:

y = i = 0 n j i ( x x j ) j i ( x i x j ) y i

带入该多项式可发现显然 f ( x i ) = y i
emmm….复杂度是 O ( n 2 ) 的?不过那是 x 值已知的情况,一定要写出多项式的话,恕本蒟蒻愚钝只会分治NTT搞…… O ( n 2 l o g n ) 的……
放出洛谷P4781的代码

#include<bits/stdc++.h>
using namespace std;
#define RI register int
const int mod=998244353,N=2005;
int n,K,ans,xx[N],yy[N];
int ksm(int x,int y) {
    int re=1;
    for(;y;y>>=1,x=1LL*x*x%mod) if(y&1) re=1LL*re*x%mod;
    return re;
}
int qm(int x) {return x>=mod?x-mod:x;}
int main()
{
    scanf("%d%d",&n,&K);
    for(RI i=1;i<=n;++i) scanf("%d%d",&xx[i],&yy[i]);
    for(RI i=1;i<=n;++i) {
        int s1=1,s2=1;
        for(RI j=1;j<=n;++j) {
            if(i==j) continue;
            s1=1LL*s1*qm(K-xx[j]+mod)%mod;
            s2=1LL*s2*qm(xx[i]-xx[j]+mod)%mod;
        }
        ans=qm(ans+1LL*s1*ksm(s2,mod-2)%mod*yy[i]%mod);
    }
    printf("%d\n",ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/litble/article/details/81749788