[SHOI2012]随机树[期望dp]

题意

初始 \(1\) 个节点,每次选定一个叶子节点并加入两个儿子直到叶子总数为 \(n\),问叶子节点深度和的平均值的期望以及最大叶子深度的期望。

\(n\leq 100\) .

分析

  • 对于第一问,根据答案定义状态 \(f_i\) 表示有 \(i\) 个叶子节点的深度和平均值的期望。

  • 考虑对于之前的每一棵树对期望的贡献,记其发生的概率为 \(p\) ,深度和为 \(w\) ,有 \(i-1\) 个叶子节点。贡献为 \(p*\frac{w}{i-1}\) 。现在要多选定一个叶子节点有 \(i-1\) 种方案,总贡献可以写成:
    \[p*\frac{1}{i-1}*\frac{(i-1)w+w+2(i-1)}{i} =\frac{ip\frac{w}{i-1}+2p}{i}\]
    也就有\(f_i=f_{i-1}+\frac{2}{i}\)

  • 对于第二问,定义状态 \(g_{i,j}\) 表示子树内有 \(i\) 个叶子,最大深度为 \(j\) 的概率。

  • 再定义 \(p_{i,j}\) 表示 \(i\) 个叶子节点有 \(j\) 个在左子树的概率,转移:
    \[f_{i,j}=\sum_{l=1}^{i-1}p_{i,l}*\sum_x\sum_y[\max(x,y)+1=j]f_{l,x}*f_{i-l,y}\]

  • \(p\) 的递推直接枚举最后一个叶子是接在左子树还是右子树即可。

  • 可以前缀和优化,总时间复杂度为 \(O(n^3)\).

代码

#include<bits/stdc++.h>
using namespace std;
#define go(u) for(int i=head[u],v=e[i].to;i;i=e[i].last,v=e[i].to)
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define pb push_back
typedef long long LL;
inline int gi(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-48;ch=getchar();}
    return x*f;
}
template<typename T>inline bool Max(T &a,T b){return a<b?a=b,1:0;}
template<typename T>inline bool Min(T &a,T b){return b<a?a=b,1:0;}
const int N=104;
int type,n;
namespace task1{
    double f[N];
    void solve(){
        f[1]=0;
        rep(i,2,n) f[i]=f[i-1]+2.0/i;
        printf("%.6lf\n",f[n]);
    }
}
namespace task2{
    double f[N][N],s[N][N],p[N][N];
    void solve(){
        p[2][1]=1;
        rep(i,3,n)rep(j,1,i-1)
            p[i][j]=( p[i-1][j-1]*1.0*(j-1)/(i-1) + p[i-1][j]*1.0*(i-1-j)/(i-1));
        
        f[1][0]=1;rep(j,0,n) s[1][j]=(j?s[1][j-1]:0)+f[1][j];
        f[2][1]=1;rep(j,1,n) s[2][j]=s[2][j-1]+f[2][j];
        
        rep(i,3,n){
            rep(j,1,i-1){
                rep(l,1,i-1)
                f[i][j]+=p[i][l]*((j-1>=0?s[l][j-1]:0)*(j-1>=0?f[i-l][j-1]:0)+(j-1>=0?f[l][j-1]:0)*(j-2>=0?s[i-l][j-2]:0));
                s[i][j]=s[i][j-1]+f[i][j];
            }
            fill(s[i]+i,s[i]+1+n,s[i][i-1]);
        }
        double ans=0;
        for(int j=0;j<=n;++j) ans+=f[n][j]*j;
        printf("%.6lf\n",ans);
    }
}
int main(){
    type=gi(),n=gi();
    if(type==1) task1::solve();
    else task2::solve();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/yqgAKIOI/p/9891146.html
今日推荐