BZOJ 图的价值 (ntt,第二类斯特林数)

Description

“简单无向图”是指无重边、无自环的无向图(不一定连通)。
一个带标号的图的价值定义为每个点度数的k次方的和。
给定n和k,请计算所有n个点的带标号的简单无向图的价值之和。
因为答案很大,请对998244353取模输出。

Input

第一行包含两个正整数n,k(1<=n<=10^9,1<=k<=200000)。

Output

 输出一行一个整数,即答案对998244353取模的结果。

Sample Input

6 5

Sample Output

67584000
#include<bits/stdc++.h>
#define ms(x) memset(x,0,sizeof(x))
#define sws ios::sync_with_stdio(false)
using namespace std;
typedef long long ll;
const int maxn=4e6+5;
const double pi=acos(-1.0);
const ll mod=998244353;///通常情况下的模数,
const ll g=3;///模数的原根998244353,1004535809,469762049

ll qpow(ll a,ll n,ll p){
    ll ans=1;
    while(n){
        if(n&1) ans=ans*a%p;
        n>>=1;
        a=a*a%p;
    }
    return ans;
}
int rev[maxn];
void ntt(ll a[],int n,int len,int pd){
    rev[0]=0;
    for(int i=1;i<n;i++){
        rev[i]=(rev[i>>1]>>1 | ((i&1)<<(len-1)));
        if(i<rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int mid=1;mid<n;mid<<=1){
        ll wn=qpow(g,(mod-1)/(mid*2),mod);///原根代替单位根
        if(pd==-1) wn=qpow(wn,mod-2,mod);///逆变换则改成逆元
        for(int j=0;j<n;j+=2*mid){
            ll w=1;
            for(int k=0;k<mid;k++){
                ll x=a[j+k],y=w*a[j+k+mid]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+mid]=(x-y+mod)%mod;
                w=w*wn%mod;
            }
        }
    }
    if(pd==-1){
        ll inv=qpow(n,mod-2,mod);
        for(int i=0;i<n;i++){
            a[i]=a[i]*inv%mod;

        }
    }
}
ll a[maxn],b[maxn],c[maxn];
void solve(int n,int m){
    int len=0,up=1;
    while(up<=n+m) up<<=1,len++;
    ntt(a,up,len,1);
    ntt(b,up,len,1);
    for(int i=0;i<up;i++) c[i]=1ll*a[i]*b[i]%mod;
    ntt(c,up,len,-1);
}
ll fa[maxn];
ll suf[maxn];
int main(){
    ll n,k;
    sws;
    cin>>n>>k;
    fa[0]=1;
    for(ll i=0;i<=k;i++){
        int t=1;
        if(i!=0) fa[i]=fa[i-1]*i%mod;
        if(i&1) t=-1;
        a[i]=t*qpow(fa[i],mod-2,mod)%mod;
        a[i]=(a[i]+mod)%mod;
        b[i]=qpow(i,k,mod)*qpow(fa[i],mod-2,mod)%mod;
    }
    solve(k,k);
    ll ans=0;
    ll up=min(k,n-1);
    ll suf=1;
    for(int i=0;i<=up;i++){
        ans=(ans+suf*c[i]%mod*qpow(2,n-1-i,mod)%mod)%mod;
        suf=suf*(n-1-i)%mod;
    }
    ans=ans*qpow(2,(n-1)*(n-2)/2,mod)%mod*n%mod;
    cout<<ans<<endl;

}

猜你喜欢

转载自www.cnblogs.com/azznaz/p/11525155.html