「LOJ6570 毛毛虫计数」 - 生成函数

LOJ6570 毛毛虫计数

tags:生成函数,多项式

题意

hsezoi 巨佬 olinr 喜欢 van 毛毛虫,他定义毛毛虫是一棵树,满足树上存在一条树链,使得树上所有点到这条树链的距离最多为 \(1\)。给定 \(n\) \((n\le10^5)\) 。现在请你求出 \(n\) 个点、有标号的毛毛虫的数量。答案对 \(998244353\) 取模。

题解

构造生成函数

对于毛毛虫直径中间的一个节点,大小为 i 总共有 i 种放法,指数型生成函数是

\[ A(x)=\sum_{i=1}^\infty\frac{ix^i}{i!} \]

对于与直径两端点相连的一个节点,强制至少挂一个节点,指数型生成函数是

\[ B(x)=\sum_{i=2}^\infty\frac{ix^i}{i!} \]

然后结果就是 \((A(x)^0+A(x)^1+\cdots)B(x)^2=\frac {B(x)^2}{1-A(x)}\)

输出:

\[ \frac{n!}2[x^n]\frac {B(x)^2}{1-A(x)} \]

注意要特判 n = 2

顺便放下我的多项式板子,需要的时候可以拉

#include<cstdio>
#include<vector>
//#define debug(...) fprintf(stderr,__VA_ARGS__)
#define debug(...) ((void)0)
typedef std::vector<int> poly;
const int P=998244353;
int fpow(int a,int b){int res=1;for(;b;b>>=1,a=1ll*a*a%P)if(b&1)res=1ll*res*a%P;return res;} 
void pt(const poly&a){for(int i=0;i<(int)a.size();++i)debug("%d ",a[i]);debug("\n");}
int getlim(int n){int x=1;while(x<=n)x<<=1;return x;}
void ntt(poly&a,int g,int lim){
    a.resize(lim);
    for(int i=0,j=0;i<lim;++i){
        if(i<j)std::swap(a[i],a[j]);
        for(int k=lim>>1;(j^=k)<k;k>>=1);
    }
    poly w(lim>>1);w[0]=1;
    for(int i=1;i<lim;i<<=1){
        for(int j=1,wn=fpow(g,(P-1)/(i<<1));j<i;++j)w[j]=1ll*w[j-1]*wn%P;
        for(int j=0;j<lim;j+=i<<1)
            for(int k=0;k<i;++k){
                int x=a[j+k],y=1ll*a[i+j+k]*w[k]%P;
                a[j+k]=(x+y)%P,a[i+j+k]=(x-y+P)%P;
            }
    }
    if(g==332748118)for(int i=0,I=fpow(lim,P-2);i<(int)a.size();++i)a[i]=1ll*a[i]*I%P;
}
poly pmul(poly a,poly b){
    int need=(int)a.size()+b.size()-1,lim=getlim(need);
    ntt(a,3,lim),ntt(b,3,lim);
    for(int i=0;i<lim;++i)a[i]=1ll*a[i]*b[i]%P;
    ntt(a,332748118,lim);
    return a.resize(need),a;
}
poly padd(poly a,poly b){
    if(a.size()<b.size()){
        for(int i=0;i<(int)a.size();++i)(b[i]+=a[i])%=P;
        return b;
    }else{
        for(int i=0;i<(int)b.size();++i)(a[i]+=b[i])%=P; 
        return a;
    }
}
poly pinv(const poly&a,int n=-1){
    if(n==-1)n=a.size();
    if(n==1)return poly(1,fpow(a[0],P-2));
    poly b=pinv(a,(n+1)>>1),tmp=poly(a.begin(),a.begin()+n);
    int lim=getlim(n*2-2);
    ntt(b,3,lim),ntt(tmp,3,lim);
    for(int i=0;i<lim;++i)b[i]=(2-1ll*b[i]*tmp[i]%P+P)%P*b[i]%P;
    ntt(b,332748118,lim);
    return b.resize(n),b;
}
poly pdao(const poly&a){
    poly b((int)a.size()-1);
    for(int i=1;i<(int)a.size();++i)b[i-1]=1ll*a[i]*i%P;
    return b;
}
poly pji(const poly&a){
    poly b((int)a.size()+1);
    for(int i=0;i<(int)a.size();++i)b[i+1]=1ll*a[i]*fpow(i+1,P-2)%P;
    return b;
}
poly pln(const poly&a){
    poly b(pmul(pdao(a),pinv(a)));
    b.resize((int)a.size()-1);
    return pji(b);
}
poly pexp(const poly&a,int n=-1){
    if(n==-1)n=a.size();
    if(n==1)return poly(1,1);
    poly b=pexp(a,(n+1)>>1),c(b);
    c.resize(n),c=pln(c),--c[0];
    for(int i=0;i<n;++i)c[i]=(a[i]-c[i]+P)%P;
    poly d(pmul(b,c));
    return d.resize(n),d;
}
const int N=100005;
int n,fac[N],inv[N];
int main(){
    fac[0]=fac[1]=inv[0]=inv[1]=1;
    for(int i=2;i<N;++i)fac[i]=1ll*fac[i-1]*i%P,inv[i]=1ll*(P-P/i)*inv[P%i]%P;
    for(int i=2;i<N;++i)inv[i]=1ll*inv[i]*inv[i-1]%P;
    scanf("%d",&n);if(n<=2)return puts("1"),0;
    poly A(n+1),B(n+1);A[0]=1;
    for(int i=1;i<=n;++i)A[i]=(P-1ll*i*inv[i]%P)%P;
    for(int i=2;i<=n;++i)B[i]=1ll*i*inv[i]%P;
    A=pinv(A),B=pmul(B,B),B.resize(n+1),A=pmul(A,B),A.resize(n+1);
    printf("%lld\n",(n+1ll*A[n]*fac[n]%P*((P+1)>>1)%P)%P);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/xay5421/p/LOJ6570.html