WC2019 tree

WC2019唯一一道正常的题,考场上没什么想法,也只拿到了暴力分。搞了一天终于做完了。

前置知识:purfer序,多项式exp或分治FTT。

对于\(type=0\)的,随便维护下,算下联通块即可。

对于\(type=1\)的,如果有\(k\)个联通块,贡献就是\(y^k\),等价于\((y-1+1)^k\),等价于\(\sum_{i=0}^{k}{k \choose i}(y-1)^i\),就是拆出一个\(1\),然后二项式展开。

原本的贡献是\(y^n\)现在每有一条边同时出现在两颗树上,那么贡献就要乘上\(y^{-1}\)

从此以后题道的所有的\(y\)都是\(y^{-1}\)

这样有什么用呢?这个式子的含义是枚举\(n​\)个点的所有边的子集\(E​\),如果\(E​\)同时是两棵树的边的子集,那么贡献就是\((y-1)^{|E|}​\),否则是\(0​\)。这样组和数已经算在里面了,每个子集的贡献最后合成了一个方案的贡献。

现在我们已知第一颗树,我们钦定一些边是两颗树上都有的边,并且假设这些边组成的联通块把数分成了\(m\)个联通块,那么这棵树的方案数为:

\(\sum_{\sum_{i=1}^{m}d_i==2(m-1)}(m-2)! \prod \frac{a_i^{d_i}}{(d_i-1)!}\)

\(d_i​\)是度数,\(a_i​\)是联通块里点的个数。

\[\frac{(m-2)!}{\prod (d_i-1)!}\]就是组合数选点放在\(prufer\)序里。\(a_i^{d_i}\)就是枚举每个从这个联通块连出去的边连在那个点上。

\(d_i​\)换成\(d_i-1​\),式子变成:

\(\sum_{\sum_{i=1}^{m}d_i==m-2}(m-2)! \prod \frac{a_i^{d_i+1}}{d_i!}\)

等价于:

\(n^{m-2}\prod_{i=1}^{m}a_i\)

现在要证明

\(\sum_{\sum_{i=1}^{m}d_i==m-2}(m-2)! \prod \frac{a_i^{d_i}}{d_i!}==n^{m-2}\)

现在我们把\(d_i​\)看成物品,除了组合选点放在\(prufer​\)序中,\(d_i​\)还要乘上\(a_i​\)。现在你直接考虑每一个\(d_i​\),它的贡献。其实他们直接互补影响,所以一个\(d_i​\)的贡献可以是任意一个\(a_i​\),所以总贡献就是\({\sum a_i}^{m-2}​\),也就是\(n^{m-2}​\)

现在我们只需要算出所有方案的贡献和即可。但是我们没有办法记录\(a_i\),所以有一个巧妙的转化,在一个大小为\(a_i\)的联通块内选一个点的方案数也是\(a_i\),所以我们只需要设\(dp[u][0/1]\),表示\(u\)\(u\)的子树的贡献和,其中\(u\)所在的联通块是否选择了一个点即可,选择一条边在边集里的贡献是\((y-1)\),不在的话意味着产生了一个新的联通块,贡献为\(n\)

对于\(ty=2​\)的,我们两棵树都不知道,所以我们直接假设至少有有\(l​\)条边在两棵树中,令\(m=n-l​\)就是有\(l​\)个联通块。

\(\sum_{\sum_{i=1}^{m}a_i ==n} \frac{n!\prod_{i=1}^{m}\frac{a_i^{a_i-2}}{a_i!}}{m!}*(n^{m-2}\prod a_i)^2\)

枚举每个联通块里有多少个点,\(a_i^{a_i-2}\)就是每个联通块里都自己构成一棵树,下面除掉\(m!\)是因为联通块之间无顺序,后面就是两棵树的方案。

带上容斥系数

\(\sum_{\sum_{i=1}^{m}a_i==n-2}(y-1)^{n-m}n^{2(m-2)}\prod_{i=1}^{m}{(\sum_{j=1}^{i}a_j)-1\choose a_i-1}a_i^{a_i}\)

大多数都是直接从上面的式子化出来的,只有一个地方:

\(\frac{n!}{m!\prod a_i!}=\prod_{i=1}^{m}{(\sum_{j=1}^{i}a_j)-1\choose a_i-1}\)

右边的意思就是每次从当前所有点数中取出\(a_i\)个,上下都要减1是因为是无序的,为了保证不算重,要强制某个特殊的点在这个联通块里,比如编号最小的那个。

考虑动态规划,设\(f[k]\)\(k\)个点的贡献之和,答案就是\(f[n]\)

初始化\(f[0]=(y-1)^n*n!*n^{-4}\),但是真正用生成函数时这些东西要最后乘上去。

转移很简单,就是枚举最后一个联通块的大小。

\(f[k]=(y-1)^{-1}*n^2*\sum_{i=1}^{k}{k-1 \choose i-1}i^i*f[k-i]\)

\(=(y-1)^{-1}*n^2*\sum_{i=1}^{k}\frac{(k-1)!*i^i}{(i-1)!(k-i)!}f[k-i]\)

两边同除\(k!\),这样我们求的东西从\(f[k]\)变成了\(\frac{f[k]}{k!}\),最后在乘回来即可,后面的\(f[k]\)均为\(\frac{f[k]}{k!}\)

\(f[k]=(y-1)^{-1}*n^2*\sum_{i=1}^{k}\frac{f[k-i]*i^i}{k*(i-1)!}\)

好像可以直接分治\(FFT​\)了。

我们令\(f=f[i]x^i\),\(g=\frac{i^i}{(i-1)!}*(y-1)^{-1}*n^2\)

因为右边等式下面有个\(\frac{1}{k}\),我们把\(k\)乘到右边去,但是这样第i项的系数就到第\(i+1\)项去了,为了满足此递推式,我们要求导。

\(f'x=fg\)

\(\frac{f'}{f}=\frac{g}{x}\)

\(ln(f)=\int \frac{g}{x} dx\)

然后你在推一推,发现\(\int \frac{g}{x} dx=\sum n^2*(y-1)^{-1}*\frac{i^i}{i!}\)

直接对这个多项式\(exp\)即可。

#include<bits/stdc++.h>
using namespace std;
typedef int sign;
typedef long long ll;
#define For(i,a,b) for(register sign i=(sign)(a);i<=(sign)(b);++i)
#define Fordown(i,a,b) for(register sign i=(sign)(a);i>=(sign)(b);--i)
template<typename T>bool cmax(T &a,T b){return (a<b)?a=b,1:0;}
template<typename T>bool cmin(T &a,T b){return (a>b)?a=b,1:0;}
template<typename T>T read()
{
    T ans=0,f=1;
    char ch=getchar();
    while(!isdigit(ch)&&ch!='-')ch=getchar();
    if(ch=='-')f=-1,ch=getchar();
    while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch-'0'),ch=getchar();
    return ans*f;
}
template<typename T>void write(T x,char y)
{
    if(x==0)
    {
        putchar('0'),putchar(y);
        return;
    }
    if(x<0)
    {
        putchar('-');
        x=-x;
    }
    static char wr[20];
    int top=0;
    for(;x;x/=10)wr[++top]=x%10+'0';
    while(top)putchar(wr[top--]);
    putchar(y);
}
void file()
{
    freopen("tree.in","r",stdin);
    freopen("tree.out","w",stdout);
}
const int N=1e5+5;
int n,m,ty;
void input()
{
    n=read<int>(),m=read<int>(),ty=read<int>();
}
const int mo=998244353;
int power(int x,int y)
{
    int res=1;
    for(;y;x=1ll*x*x%mo,y>>=1)if(y&1)res=1ll*res*x%mo;
    return res;
}
namespace sub1
{
    vector<int>E[N];
    unordered_map<int,bool>mp[N];
    int fa[N];
    int find(int x){return x==fa[x]?x:fa[x]=find(fa[x]);}
    void solve()
    {
        int x,y;
        For(i,2,n)
        {
            x=read<int>(),y=read<int>();
            E[x].push_back(y);
            E[y].push_back(x);
        }
        For(i,2,n)
        {
            x=read<int>(),y=read<int>();
            mp[x][y]=1,mp[y][x]=1;
        }
        For(i,1,n)fa[i]=i;
        For(u,1,n)for(int v:E[u])
        {
            if(mp[u][v])fa[find(v)]=find(u);
        }
        int cnt=0;
        For(i,1,n)if(i==fa[i])cnt++;
        write(power(m,cnt),'\n');
    }
}
int dp0,dp1;
void add(int &x,int y){x+=y;x-=(x>=mo?mo:0);}
namespace sub2
{
    int invm;
    vector<int>E[N];
    int dp[N][2];
    void DP(int u,int pre)
    {
        dp[u][0]=dp[u][1]=1;
        for(int v:E[u])if(v^pre)
        {
            DP(v,u);    
            dp0=dp1=0;
            add(dp0,1ll*dp[u][0]*dp[v][0]%mo*(invm-1)%mo);
            add(dp0,1ll*dp[u][0]*dp[v][1]%mo*n%mo);
            
            //add(dp1,1ll*dp[u][0]*dp[v][0]%mo);
            add(dp1,1ll*dp[u][1]*dp[v][0]%mo*(invm-1)%mo);
            add(dp1,1ll*dp[u][0]*dp[v][1]%mo*(invm-1)%mo);
            add(dp1,1ll*dp[u][1]*dp[v][1]%mo*n%mo);
        
            dp[u][0]=dp0;
            dp[u][1]=dp1;
        }

    }
    void solve()
    {
        int x,y;
        For(i,2,n)
        {
            x=read<int>(),y=read<int>();
            E[x].push_back(y);
            E[y].push_back(x);
        }
        invm=power(m,mo-2);
        DP(1,0);                
        write(1ll*dp[1][1]*power(n,mo-2)%mo*power(m,n)%mo,'\n');
    }
}
const int max_log=22;
int Mod(int x){return x>=mo?x-mo:x;}
namespace NTT
{
    int f[1<<max_log],g[1<<max_log];
    void init(int n)
    {
        int len=1;
        for(;len<=n;len<<=1);
        for(int i=2;i<=len;i<<=1)
        {
            f[i]=power(3,(mo-1)/i);
            g[i]=power(f[i],mo-2);
        }
    }
    int rev[1<<max_log]; 
    #define rg register
    void NTT(int *p,int len,int type)
    {
        static int u,v,tim,wn,x;
        For(i,1,len-1)if(i<rev[i])swap(p[i],p[rev[i]]);
        for(rg int i=2;i<=len;i<<=1)
        {
            tim=i>>1;
            wn=(type==1?f[i]:g[i]);
            assert(wn);
            for(rg int j=0;j<len;j+=i)
            {
                x=1;
                for(rg int k=0;k<tim;++k,x=1ll*x*wn%mo)
                {
                    u=p[j+k],v=1ll*p[j+k+tim]*x%mo;
                    p[j+k]=Mod(u+v);
                    p[j+k+tim]=Mod(u-v+mo);
                }
            }
        }
        if(type==-1)
        {
            int inv=power(len,mo-2);
            For(i,0,len-1)p[i]=1ll*p[i]*inv%mo;
        }
    }   
    int A[1<<max_log],B[1<<max_log],C[1<<max_log];
    void get_rev(int len,int cnt)
    {
        For(i,1,len-1)rev[i]=rev[i>>1]>>1|((i&1)<<(cnt-1));
    }
    void mul(int *a,int *b,int *c,int n1,int n2)
    {
        int len,cnt;
        for(len=1,cnt=0;len<=n1+n2;len<<=1)cnt++;
        get_rev(len,cnt);
        For(i,0,n1)A[i]=a[i];
        For(i,0,n2)B[i]=b[i];
        NTT(A,len,1),NTT(B,len,1);
        For(i,0,len)C[i]=1ll*A[i]*B[i]%mo;
        NTT(C,len,-1);
        For(i,0,n1+n2)c[i]=C[i];
        For(i,0,len)A[i]=B[i]=0;
    }
}
namespace INV
{
    int C[1<<max_log],D[1<<max_log];
    void get_inv(int *A,int *B,int len,int cnt)
    {
        if(len==1){B[0]=power(A[0],mo-2);return;}
        get_inv(A,B,len>>1,cnt-1);
        NTT::get_rev(len<<1,cnt+1);
        For(i,0,len-1)C[i]=A[i],D[i]=B[i];
        NTT::NTT(C,len<<1,1);
        NTT::NTT(D,len<<1,1);
        For(i,0,(len<<1)-1)C[i]=1ll*D[i]*D[i]%mo*C[i]%mo;
        NTT::NTT(C,len<<1,-1);
        For(i,len>>1,len-1)B[i]=Mod(Mod(B[i]+B[i])-C[i]+mo);
        For(i,0,(len<<1)-1)C[i]=D[i]=0;
    }
    int A[1<<max_log],B[1<<max_log];
    void inv(int *a,int *b,int n)
    {
        int len=1,cnt=0;
        for(;len<=n;len<<=1)cnt++;
        For(i,0,len<<1)A[i]=B[i]=C[i]=D[i]=0;
        For(i,0,n)A[i]=a[i];
        get_inv(A,B,len,cnt);
        For(i,0,n)b[i]=B[i];
    }
}
namespace LN
{
    int A[1<<max_log],B[1<<max_log],C[1<<max_log];
    void Direv(int *a,int *b,int n)
    {
        For(i,1,n)b[i-1]=1ll*a[i]*i%mo;
    }
    int inv[1<<max_log],ans[1<<max_log];
    void Inter(int *a,int *b,int n)
    {
        inv[1]=1;
        For(i,1,n)
        {
            if(i>1)inv[i]=1ll*inv[mo%i]*(mo-mo/i)%mo;
            b[i]=1ll*a[i-1]*inv[i]%mo;
        }
    }
    void get_ln(int *a,int *b,int n)
    {
        Direv(a,B,n);
        INV::inv(a,C,n);
        NTT::mul(B,C,A,n-1,n);
        Inter(A,ans,n);
        For(i,0,n)b[i]=ans[i];
        
    }
}
namespace EXP
{
    int A[1<<max_log],B[1<<max_log],C[1<<max_log],D[1<<max_log];
    void exp(int *a,int *b,int n)
    {
        For(i,0,n)A[i]=a[i];
        B[0]=1;
        int len=2,cnt=1;
        while(len<=(n<<1))
        {
            For(i,0,(len>>1)-1)C[i]=D[i]=B[i];
            LN::get_ln(C,D,len-1);
            For(i,0,len-1)D[i]=Mod(A[i]-D[i]+mo);
            D[0]=Mod(D[0]+1);
            NTT::get_rev(len<<1,cnt+1);
            NTT::NTT(C,len<<1,1);
            NTT::NTT(D,len<<1,1);
            For(i,0,(len<<1)-1)C[i]=1ll*C[i]*D[i]%mo;
            NTT::NTT(C,len<<1,-1);
            For(i,0,len-1)B[i]=C[i];
            For(i,0,len<<1)C[i]=D[i]=0;
            len<<=1,cnt++;
        }
        For(i,0,n)b[i]=B[i];
    }
}
namespace sub3
{
    int f[1<<max_log],g[1<<max_log];
    int mc[N],inv[N];
    void solve()
    {
        if(m==1){write(power(power(n,n-2),2),'\n');return;}
        NTT::init(n<<2);
        int invm=power(m,mo-2);
        mc[0]=inv[0]=1;
        For(i,1,n)mc[i]=1ll*mc[i-1]*i%mo;
        inv[n]=power(mc[n],mo-2);
        Fordown(i,n-1,1)inv[i]=1ll*inv[i+1]*(i+1)%mo;
        int tmp=power(invm-1,mo-2);
        int d=1ll*tmp*n%mo*n%mo;
        For(i,1,n)g[i]=1ll*d*power(i,i)%mo*inv[i]%mo;
        EXP::exp(g,f,n);
        int ans=f[n];
        //cerr<<ans<<endl;
        ans=1ll*ans*power(power(n,mo-2),4)%mo*mc[n]%mo*power(m,n)%mo*power(invm-1,n)%mo;
        write(ans,'\n');
    }
}
void work()
{
    if(ty==0)sub1::solve();
    else if(ty==1)sub2::solve();
    else if(ty==2)sub3::solve();
}
int main()
{
    file();
    input();
    work();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/dengyixuan/p/10517014.html