[NOWCODER] myh的超级多项式

题面

已知$f_i=(\sum_{j=1}^ka_j{v_j}^i )\bmod 1004535809$

给定$v_1,v_2,\ldots,v_k,f_1,f_2,\ldots f_k$

求$f_n$

思路

我们考虑构造一个递推式,使得:

$f_n=\sum_{i=1}^k c_i f_{n-i}$

我们把这个$f_n$挪到右边来,令$c_0=1$,得到:

$\sum_{i=0}^k c_i f_{n-i} =0$

即:

$\sum_{i=0}^k c_i \sum_{j=1}^k a_j v_j^{n-i}=0$

这个式子的一个充分条件(可行条件)

$\forall j \in [1,k] \sum_{i=0}^k c_i a_j v_j^{n-i}=0$

把$a_j$挪到前面去,除掉一部分$v_j$的幂,得到这个式子:

$\forall j \in [1,k] \sum_{i=0}^k c_i v_j^{k-i}=0$

令$F(x)=\sum c_{k-i} x^i$,那么我们发现${v}$数组是$F(x)$的所有0点

又因为$c_0=-1$,所以$F(x)=-\prod_{i=1}^k (x-v_i)$

分治FFT求出$F(x)$,然后用$O((n-k)k)$递推(不会TLE)得到$f_n$即可

Code

代码里有一个技巧

因为一段区间得到的n+1个系数的多项式的最高次项一定是1,所以我们可以不保存他

这样分治FFT用长度为n的数组就能保存了

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MOD 1004535809
#define ll long long
using namespace std;
inline int read(){
    int re=0,flag=1;char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-') flag=-1;
        ch=getchar();
    }
    while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
    return re*flag;
}
ll qpow(ll a,ll b){
    ll re=1;
    while(b){
        if(b&1) re=re*a%MOD;
        a=a*a%MOD;b>>=1;
    }
    return re;
}
ll add(ll a,ll b){
    a+=b;
    return ((a>=MOD)?a-MOD:a);
}
ll dec(ll a,ll b){
    a-=b;
    return ((a<0)?a+MOD:a);
}
ll g=3,ginv;
namespace NTT{
    int lim,cnt,r[400010];
    ll A[400010],B[400010];
    void ntt(ll *a,ll type){
        int i,j,k,mid;ll x,y,w,wn,inv;
        for(i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
        for(mid=1;mid<lim;mid<<=1){
            wn=qpow(((~type)?g:ginv),(MOD-1)/(mid<<1));
            for(j=0;j<lim;j+=(mid<<1)){
                w=1;
                for(k=0;k<mid;k++,w=w*wn%MOD){
                    x=a[j+k];y=a[j+k+mid]*w%MOD;
                    a[j+k]=add(x,y);
                    a[j+k+mid]=dec(x,y);
                }
            }
        }
        if(~type) return;
        inv=qpow(lim,MOD-2);
        for(i=0;i<lim;i++) a[i]=a[i]*inv%MOD;
    }
    void init(int n){
        int i;
        lim=1;cnt=0;
        while(lim<=n) lim<<=1,cnt++;
        for(i=0;i<lim;i++) r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1))),A[i]=B[i]=0;
    }
}
void mul(){
    using namespace NTT;
    ntt(A,1);ntt(B,1);int i;
    for(i=0;i<lim;i++) A[i]=A[i]*B[i]%MOD;
    ntt(A,-1);
}
ll c[100010];//黑科技数组
int n,k;ll v[100010],f[100010];
void solve(int l,int r){
    if(l==r){
        c[l]=MOD-v[l];
        return;
    }
    int mid=(l+r)>>1,i;
    solve(l,mid);solve(mid+1,r);
    using namespace NTT;
    init(r-l+1);
    for(i=0;i<=mid-l;i++) A[i]=c[i+l];
    for(i=0;i<r-mid;i++) B[i]=c[i+mid+1];
    A[mid-l+1]=B[r-mid]=1;//把没记录的1加上
    mul();
    for(i=0;i<=r-l;i++) c[l+i]=A[i];//这里不保存1
}
int main(){
    n=read();k=read();int i,j;
    g=3;ginv=qpow(3,MOD-2);
    for(i=1;i<=k;i++) v[i]=read();
    for(i=1;i<=k;i++) f[i]=read();
    solve(1,k);
    for(i=0;i<k;i++) c[i]=c[i+1];
    c[k]=1;
    for(i=0;i<=k;i++) if(c[i]) c[i]=MOD-c[i];
    for(i=0;i<=k/2;i++) swap(c[i],c[k-i]);
    for(i=k+1;i<=n;i++){
        ll w=0;
        for(j=1;j<=k;j++) w+=c[j]*f[i-j]%MOD;
        f[i]=w%MOD;
    }
    printf("%lld\n",f[n]);
}

猜你喜欢

转载自www.cnblogs.com/dedicatus545/p/9728908.html