LG4512 【模板】多项式除法

P4512 【模板】多项式除法

题目描述

给定一个 $n$ 次多项式 $F(x)$ 和一个 $m$ 次多项式 $G(x)$ ,请求出多项式 $Q(x)$, $R(x)$,满足以下条件:

  • $Q(x)$ 次数为 $n-m$,$R(x)$ 次数小于 $m$
  • $F(x) = Q(x) * G(x) + R(x)$

所有的运算在模 $998244353$ 意义下进行。

输入输出格式

输入格式:

第一行两个整数 $n$,$m$,意义如上。
第二行 $n+1$ 个整数,从低到高表示 $F(x)$ 的各个系数。
第三行 $m+1$ 个整数,从低到高表示 $G(x)$ 的各个系数。

输出格式:

第一行 $n-m+1$ 个整数,从低到高表示 $Q(x)$ 的各个系数。
第二行 $m$ 个整数,从低到高表示 $R(x)$ 的各个系数。
如果 $R(x)$ 不足 $m-1$ 次,多余的项系数补 $0$。

输入输出样例

输入样例#1: 复制
5 1
1 9 2 6 0 8
1 7
输出样例#1: 复制
237340659 335104102 649004347 448191342 855638018
760903695

说明

对于所有数据,$1 \le m < n \le 10^5$,给出的系数均属于 $[0, 998244353) \cap \mathbb{Z}$。

题解

仍然在跟边界交手。
\[ f_R(x)\equiv (q_R\times g_R)(x) \mod x^{n-m+1} \]
翻转过后,因为这个原始的式子,所以都只能取\(n-m+1\)项。\(g\)求逆前就该保留\(n-m+1\)项。
\[ q_R(x)\equiv (f_R\times g_R^{-1})(x) \mod x^{n-m+1} \]
然后\(q_R\)翻转一下就得到\(q\)。回代\(r(x)=f(x)-(q\times g)(x)\)即可解出\(r\),只需要算前\(m\)项即可。

\(T(n)=O(n\log n)\)

#include<bits/stdc++.h>
#define il inline
#define co const
template<class T>T read(){
    T data=0,w=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-') w=-w;
    for(;isdigit(ch);ch=getchar()) data=data*10+ch-'0';
    return data*w;
}
template<class T>il T read(T&x) {return x=read<T>();}
typedef long long LL;
using namespace std;
typedef vector<int> polynomial;

co int mod=998244353,g=3,g_inv=332748118;
il int add(int a,int b){
    return (a+=b)>=mod?a-mod:a;
}
il int mul(int a,int b){
    return (LL)a*b%mod;
}
int fpow(int a,int b){
    int ans=1;
    for(;b;b>>=1,a=mul(a,a))
        if(b&1) ans=mul(ans,a);
    return ans;
}
void num_trans(polynomial&a,int inverse){
    int limit=a.size(),len=log2(limit);
    static vector<int> bit_rev;
    if(bit_rev.size()!=limit){
        bit_rev.resize(limit);
        for(int i=0;i<limit;++i) bit_rev[i]=bit_rev[i>>1]>>1|(i&1)<<(len-1);
    }
    for(int i=0;i<limit;++i)if(i<bit_rev[i]) swap(a[i],a[bit_rev[i]]);
    for(int step=1;step<limit;step<<=1){
        int gn=fpow(inverse==1?g:g_inv,(mod-1)/(step<<1));
        for(int even=0;even<limit;even+=step<<1){
            int odd=even+step,gk=1;
            for(int k=0;k<step;++k,gk=mul(gk,gn)){
                int t=mul(gk,a[odd+k]);
                a[odd+k]=add(a[even+k],mod-t),a[even+k]=add(a[even+k],t);
            }
        }
    }
    if(inverse==-1){
        int lim_inv=fpow(limit,mod-2);
        for(int i=0;i<limit;++i) a[i]=mul(a[i],lim_inv);
    }
}
polynomial poly_inv(polynomial a,int n){
    polynomial b[2];
    b[0].push_back(fpow(a[0],mod-2));
    if(n==1) return b[0]; // edit 2
    a.resize(1<<int(ceil(log2(n))+1));
    int limit,len;
    for(limit=2,len=1;limit<n;limit<<=1,++len){
        polynomial a1(a.begin(),a.begin()+limit);
        a1.resize(limit<<1),num_trans(a1,1);
        b[(len&1)^1].resize(limit<<1),num_trans(b[(len&1)^1],1);
        b[len&1].resize(limit<<1);
        for(int i=0;i<limit<<1;++i) b[len&1][i]=mul(add(2,mod-mul(a1[i],b[(len&1)^1][i])),b[(len&1)^1][i]);
        num_trans(b[len&1],-1),b[len&1].resize(limit);
    }
    assert(a.size()==limit<<1),num_trans(a,1);
    b[(len&1)^1].resize(limit<<1),num_trans(b[(len&1)^1],1);
    b[len&1].resize(limit<<1);
    for(int i=0;i<limit<<1;++i) b[len&1][i]=mul(add(2,mod-mul(a[i],b[(len&1)^1][i])),b[(len&1)^1][i]);
    num_trans(b[len&1],-1),b[len&1].resize(n);
    return b[len&1];
}
polynomial poly_div(polynomial f,polynomial g){ // return the quotient
    int n=f.size()-1,m=g.size()-1;
    reverse(g.begin(),g.end()),g.resize(n-m+1),g=poly_inv(g,n-m+1); // edit 1:access partially
    reverse(f.begin(),f.end()),f.resize(n-m+1);
    int limit=1<<int(ceil(log2(2*(n-m)+1)));
    f.resize(limit),g.resize(limit);
    num_trans(f,1),num_trans(g,1);
    for(int i=0;i<limit;++i) f[i]=mul(f[i],g[i]);
    num_trans(f,-1),f.resize(n-m+1);
    return reverse(f.begin(),f.end()),f;
}
int main(){
//  freopen("LG4512.in","r",stdin);
    int n=read<int>(),m=read<int>();
    polynomial f(n+1),g(m+1);
    for(int i=0;i<=n;++i) read(f[i]);
    for(int i=0;i<=m;++i) read(g[i]);
    polynomial q=poly_div(f,g);
    for(int i=0;i<=n-m;++i) printf("%d ",q[i]);
    puts("");
    // poly_mod
    int limit=1<<int(ceil(log2(n)));
    g.resize(limit),q.resize(limit);
    num_trans(g,1),num_trans(q,1);
    for(int i=0;i<limit;++i) g[i]=mul(g[i],q[i]);
    num_trans(g,-1);
    for(int i=0;i<m;++i) printf("%d ",add(f[i],mod-g[i]));
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/autoint/p/11116074.html
今日推荐