LOJ 2409「THUPC 2017」小 L 的计算题 / Sum

思路

思路与「玩游戏」一题类似：把幂和写成生成函数。
定义 \(A_k(x)=\sum_{i=0}^{\infty}a_k^ix^i=\frac{1}{1-a_kx}\)，则所求的 \(\sum_{k=1}^n a_k^i\) 正是 \(\sum_{k=1}^n A_k(x)\) 的 \(x^i\) 项系数。

注意到 \(\ln'(x)=\frac{1}{x}\)，于是可以用 \(\ln'\) 来表示 \(\frac{1}{x}\)。

所以要求的就是
\[ f(x)=\sum_{i=1}^n A_i(x)=\sum_{i=1}^n \frac{1}{1-a_ix}=\sum_{i=1}^n \ln'(1-a_ix) \]
但直接对这 \(n\) 个分式求和没法快速计算。

所以再设 \(G(x)=\sum_{i=1}^n \left(\ln(1-a_ix)\right)'\)（对 \(x\) 求导）。

由链式法则
\[ G(x)=\sum_{i=1}^n\frac{-a_i}{1-a_ix} \]
又因为 \(\frac{a_ix}{1-a_ix}=\frac{1}{1-a_ix}-1\)，所以
\[ f(x)=-xG(x)+n \]

而 \(\ln\) 能把乘积变成求和，所以
\[ G(x)=\left(\ln\prod_{i=1}^n (1-a_ix)\right)' \]

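于是先用分治 + NTT 在 \(O(n\log^2n)\) 的时间内求出 \(\prod_{i=1}^n(1-a_ix)\)，再做一次多项式 \(\ln\)（\(O(n\log n)\)）并求导得到 \(G(x)\)，进而得到 \(f(x)\)。总时间复杂度为 \(O(n\log^2n)\)。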

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#define int long long
using namespace std;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
static const int MOD = 998244353;
// Primitive root of MOD used for the forward transform.
static const int G = 3;
// Modular inverse of G (3 * 332748118 == 1 mod MOD), used for the inverse transform.
static const int invG = 332748118;
// Capacity of every global / static coefficient buffer in this file.
static const int MAXN = 2000000;
int pow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)
            ans=(1LL*ans*a)%MOD;
        a=(1LL*a*a)%MOD;
        b>>=1;
    }
    return ans;
}
// Bit-reversal permutation table, filled by cal_rev().
int rev[MAXN];
// inv_val[i] = i^{-1} mod MOD, filled by cal_inv().
int inv_val[MAXN];
// Fill rev[0..n-1] with each index's low `lim` bits reversed.
// rev[i] is derived from rev[i>>1]: shift it right one place, and put i's
// lowest bit into the top position (bit lim-1).
void cal_rev(int n,int lim){
    for(int i=0;i<n;++i){
        const int top=(i&1)<<(lim-1);
        rev[i]=top|(rev[i>>1]>>1);
    }
}
// Precompute inv_val[i] = i^{-1} mod MOD for 1 <= i <= n in linear time using
//   i^{-1} = -(MOD / i) * (MOD % i)^{-1}   (mod MOD).
// inv_val[0] is set to 0 as a placeholder.
void cal_inv(int n){
    inv_val[0]=0;
    inv_val[1]=1;
    int i=2;
    while(i<=n){
        const int q=MOD/i;
        const int r=MOD%i;
        inv_val[i]=(1LL*(MOD-q)*inv_val[r])%MOD;
        ++i;
    }
}
// In-place iterative number-theoretic transform of a[0..n-1] modulo MOD.
// opt != 0: forward transform with root G; opt == 0: inverse transform with
// root invG, followed by scaling by n^{-1}.
// Assumes n is a power of two and rev[] was filled by cal_rev(n, lim);
// `lim` itself is not read here.
void NTT(int *a,int opt,int n,int lim){
    // Reorder into bit-reversed order so the butterflies can run in place.
    for(int i=0;i<n;i++){
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(int step=2;step<=n;step<<=1){
        const int half=step>>1;
        const int root=opt?G:invG;
        // Principal step-th root of unity for this stage.
        const int wn=pow(root,(MOD-1)/step);
        for(int blk=0;blk<n;blk+=step){
            int w=1;
            for(int k=0;k<half;k++){
                int &lo=a[blk+k];
                int &hi=a[blk+k+half];
                const int t=(1LL*hi*w)%MOD;
                hi=(lo-t+MOD)%MOD;
                lo=(lo+t)%MOD;
                w=(1LL*w*wn)%MOD;
            }
        }
    }
    if(opt==0){
        const int invN=pow(n,MOD-2);
        for(int i=0;i<n;i++)
            a[i]=(1LL*a[i]*invN)%MOD;
    }
}
void mul(int *a,int *b,int &at,int bt){
    int midlen=1,midlim=0;
    while((midlen)<(at+bt+2))
        midlen<<=1,midlim++;
    cal_rev(midlen,midlim);
    static int tmp[MAXN];
    for(int i=0;i<midlen;i++)
        tmp[i]=b[i];
    NTT(a,1,midlen,midlim);
    NTT(tmp,1,midlen,midlim);
    for(int i=0;i<midlen;i++)
        a[i]=(1LL*a[i]*tmp[i])%MOD;
    NTT(a,0,midlen,midlim);
    at+=bt;
    for(int i=0;i<=at;i++)
        tmp[i]=0;
    for(int i=at+1;i<midlen;i++)
        a[i]=0,tmp[i]=0;
}
// Power-series inverse: b = a^{-1} mod x^dep, via Newton iteration
// b <- b * (2 - a*b).
// midlen/midlim carry the current transform size (and its log2) across
// recursion levels, and are only ever grown here.
// b[dep .. midlen-1] is cleared on return.
// Assumes a[0] is invertible mod MOD, and that b is zero beyond the
// coefficients written by the recursion.
void inv(int *a,int *b,int dep,int &midlen,int &midlim){
    // Base case: invert the constant term.
    if(dep==1){
        b[0]=pow(a[0],MOD-2);
        return;
    }
    // First obtain the inverse modulo x^{ceil(dep/2)}.
    inv(a,b,(dep+1)>>1,midlen,midlim);
    static int tmp[MAXN];
    // Need room for 2*dep coefficients so the a*b*b product does not wrap.
    for(;midlen<(dep<<1);midlen<<=1)
        ++midlim;
    for(int i=0;i<midlen;i++)
        tmp[i]=(i<dep)?a[i]:0;
    cal_rev(midlen,midlim);
    NTT(tmp,1,midlen,midlim);
    NTT(b,1,midlen,midlim);
    for(int i=0;i<midlen;i++){
        const int ab=1LL*tmp[i]*b[i]%MOD;
        b[i]=1LL*b[i]*((2-ab+MOD)%MOD)%MOD;
    }
    NTT(b,0,midlen,midlim);
    for(int i=dep;i<midlen;i++)
        b[i]=0;
}
// In-place derivative of the polynomial a[0..t]: a[i] <- (i+1) * a[i+1].
// Reads a[t], clears it, and decrements t to the new degree bound.
void qd(int *a,int &t){
    for(int i=0;i<t;i++){
        const int coef=a[i+1];
        a[i]=1LL*coef*(i+1)%MOD;
    }
    a[t--]=0;
}
// In-place integral with zero constant term: a[i] <- a[i-1] / i.
// Increments the degree bound t.
// Division uses the inv_val[] table, so it must cover indices up to the new
// t to be exact at every position.
void jf(int *a,int &t){
    ++t;
    for(int i=t;i>=1;--i)
        a[i]=1LL*a[i-1]*inv_val[i]%MOD;
    a[0]=0;
}
// Power-series logarithm: b = ln(a) mod x^n, computed as the integral of a' * a^{-1}.
// Assumes a[0] == 1 (so the logarithm exists). On return b[0..n-1] holds the
// result and every other entry this routine touched is zero.
void ln(int *a,int *b,int n){
    static int tmp[MAXN];
    // Transform size mul() will pick for an (n-1)-degree by n-term product
    // (smallest power of two >= 2n+1). It also bounds inv()'s transform size
    // (smallest power of two >= 2n).
    int conv=1;
    while(conv<2*n+1)
        conv<<=1;
    // qd() reads b[n], and mul()/inv() transform b and tmp over the full
    // transform size. Clear them explicitly instead of relying on the caller
    // or a previous call to have left them zero.
    for(int i=0;i<conv;i++){
        tmp[i]=0;
        b[i]=(i<n)?a[i]:0;
    }
    int midlen=1,midlim=0;
    inv(a,tmp,n,midlen,midlim);   // tmp = a^{-1} mod x^n
    int t=n;
    qd(b,t);                      // b = a'
    mul(b,tmp,t,n);               // b = a' * a^{-1}
    jf(b,t);                      // b = integral, i.e. ln a
    // Leave only the n requested coefficients; reset the scratch buffer.
    for(int i=0;i<conv;i++){
        tmp[i]=0;
        if(i>=n)
            b[i]=0;
    }
}
// n: size of the current test case.
int n;
// a[1..n]: input values; b: prod (1 - a_i x) built by solve();
// c: ln(b), later transformed into the power-sum series.
int a[MAXN],b[MAXN],c[MAXN];
// Scratch buffers for solve(): each recursion level takes two (la, ra) by
// bumping cnt and hands them back on exit, so depth is limited to 20 levels.
// NOTE(review): this reserves 40*MAXN elements; worth checking against the
// judge's memory limit.
int cnt=0;
int barrel[40][MAXN];
// Divide-and-conquer product: val = prod_{i=l}^{r} (1 - a_i x) modulo MOD, with
// len set to its degree (r - l + 1).
// val must have room for the transform size used at this level. Child results
// live in barrel[] slots, which are expected to be all-zero on entry and are
// cleared again before return.
void solve(int l,int r,int *val,int &len){
    if(l==r){
        val[0]=1;
        // Reduce a[l] first so the coefficient is canonical in [0, MOD).
        // The old `MOD - a[l]` gave MOD for a[l]==0, and went out of range
        // (negative) for a[l] >= MOD, which NTT's %-arithmetic then propagated.
        val[1]=(MOD-a[l]%MOD)%MOD;
        len=1;
        return;
    }
    int *la=barrel[cnt++],*ra=barrel[cnt++];
    int num=cnt,lenl,lenr;
    int mid=(l+r)>>1;
    solve(l,mid,la,lenl);
    solve(mid+1,r,ra,lenr);
    // Transform size: smallest power of two strictly greater than lenl+lenr+1.
    int midlen=1,midlim=0;
    while(midlen<(lenl+lenr+2))
        midlen<<=1,midlim++;
    cal_rev(midlen,midlim);
    NTT(la,1,midlen,midlim);
    NTT(ra,1,midlen,midlim);
    for(int i=0;i<midlen;i++)
        val[i]=(1LL*la[i]*ra[i])%MOD;
    NTT(val,0,midlen,midlim);
    // Return the scratch slots to the pool in a clean state.
    for(int i=0;i<midlen;i++)
        la[i]=ra[i]=0;
    len=lenl+lenr;
    cnt=num-2;
}
// For each test, read a_1..a_n and print the XOR over k = 1..n of
// s_k = sum_i a_i^k mod MOD, via f(x) = n - x * (ln prod (1 - a_i x))'.
signed main(){
    int T;
    if(scanf("%lld",&T)!=1)
        return 0;
    while(T--){
        if(scanf("%lld",&n)!=1)
            return 0;
        if(n<=0){
            // No k in 1..n: the XOR over an empty range is 0. (solve(1, 0)
            // would otherwise recurse forever.)
            printf("%lld\n",0LL);
            continue;
        }
        cal_inv(n+10);
        for(int i=1;i<=n;i++)
            scanf("%lld",&a[i]);
        int len=0;
        solve(1,n,b,len);          // b = prod (1 - a_i x)
        ln(b,c,n+1);               // c = ln b mod x^{n+1}
        int t=n;
        qd(c,t);                   // c = G(x) = sum -a_i / (1 - a_i x)
        // f(x) = n - x G(x): shift up one place and negate. Reduce mod MOD so
        // a zero coefficient stays 0. The old `MOD - c[i-1]` produced MOD
        // whenever s_i == 0 (mod MOD), which corrupted the XOR below.
        for(int i=n;i>=1;i--)
            c[i]=(MOD-c[i-1])%MOD;
        c[0]=n;
        int ans=0;
        for(int i=1;i<=n;i++)
            ans^=c[i];
        printf("%lld\n",ans);
        for(int i=0;i<=n;i++)
            a[i]=b[i]=c[i]=0;
    }
    return 0;
}

猜你喜欢

转载自 https://www.cnblogs.com/dreagonm/p/10780281.html