题解 hdu4624 Endless Spin

题目链接

题目大意:

有长度为\(n\)的区间,每次随机选择一段(左右端点都是整数)染黑,问期望多少次全部染黑。

\(n\leq 50\)

\(n\)个随机变量\(t_1,...,t_n\)\(t_i\)表示第一次覆盖到\(i\)的时间的期望。则我们要求的是\(\displaystyle\max_{i=1}^{n}(E(t_i))\)

考虑minmax容斥:

\[\max_{x\in s}(E(x))=\sum_{t\subseteq s}(-1)^{|t|+1}\min_{x\in t}(E(x))\]

这样我们就转化为对于每点集\(s\),求它第一次被覆盖到的期望操作次数(覆盖到其中任何一个点都算覆盖)。

如果我们知道了只操作一次的情况下它被覆盖概率\(p\),则期望操作次数就是\(\frac{1}{p}\)。(例如掷一次骰子掷到\(3\)的概率是\(\frac{1}{6}\),则期望掷\(6\)次可以第一次得到\(3\))。

这个还是不好求,我们转而求操作一次,\(s\)中的点一个都覆盖不到的概率\(p'\),则\(p=1-p',E(s)=\frac{1}{1-p'}\)

考虑如果暴力枚举一个子集\(s\)。则整个数列被\(s\)内的点划分成若干个区间,设长度分别为:\(l_1,l_2,...,l_k\)。则\(p'=\frac{\sum_{i=1}^{k}\frac{1}{2}l_i(l_i+1)}{\frac{1}{2}n(n+1)}\)。复杂度\(O(2^nn)\),无法承受。

扫描二维码关注公众号,回复: 9441301 查看本文章

考虑DP。设\(dp(i,j,k,0/1)\)表示考虑了前\(i\)个位置,最多能取\(j\)个区间\((j\leq \frac{1}{2}n(n+1))\),使得没有区间覆盖到点集内的点。上一次选的点集里的点距离\(i\)\(k\),点集的大小奇偶性为\(0/1\)。这样选出区间的方案数。

转移时考虑第\(i+1\)个位置是否加入点集:

  • 如果加入点集:\(f(i+1,j,0,1/0)+=f(i,j,k,0/1)\).

  • 如果不加入点集:\(f(i+1,j+k+1,k+1,0/1)+=f(i,j,k,0/1)\).

转移是\(O(1)\)的,所以DP的复杂度\(O(n^4)\)

统计答案时把所有\(j\)的情况加起来即可。即:\(ans(n)=\displaystyle\sum_{j=0}^{\frac{1}{2}n(n+1)-1}\frac{f(n,j,k,0/1)\times(-1)^{1/0}}{1-\frac{j}{\frac{1}{2}n(n+1)}}\).

备注:具体实现的时候把分数上下同时乘以\(\frac{1}{2}n(n+1)\)会更好写。式子上面的\(k\)表示所有\(k\)的情况的和。\((-1)\)的指数上的\(0/1\)之所以和状态里的\(0/1\)相反是因为minmax容斥的式子本来就是\((-1)^{|t|+1}\)

注意本题要使用高精度。

参考代码:

//problem:hdu4624
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
    if(S==T){
        T=(S=buf)+fread(buf,1,MAXN,stdin);
        if(S==T)return EOF;
    }
    return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
    #define getchar Fread::getchar
#endif
inline int read(){
    int f=1,x=0;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
inline ll readll(){
    ll f=1,x=0;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
/*  ------  by:duyi  ------  */ // dysyn1314

namespace Bigdouble{
    const int K=50;
    typedef long long ll;
    struct db{ll zs,xs[K+5];db(){zs=0;memset(xs,0,sizeof(xs));}};
    db makedb(ll fz,ll fm){
        db res;
        res.zs=fz/fm,fz%=fm,fz*=10;
        for(int i=1;i<=K;++i)res.xs[i]=fz/fm,fz%=fm,fz*=10;
        return res;
    }
    db operator + (db a,db b){
        db res;ll jw=0;
        for(int i=K;i>=1;--i)res.xs[i]=a.xs[i]+b.xs[i]+jw,jw=res.xs[i]/10,res.xs[i]%=10;
        res.zs=a.zs+b.zs+jw;
        return res;
    }
    db operator - (db a,db b){
        db res;
        for(int i=K;i>=2;--i){
            if(a.xs[i]<b.xs[i])a.xs[i-1]--,a.xs[i]+=10;
            res.xs[i]=a.xs[i]-b.xs[i];
        }
        if(a.xs[1]<b.xs[1])a.zs--,a.xs[1]+=10;
        res.xs[1]=a.xs[1]-b.xs[1];
        res.zs=a.zs-b.zs;
        return res;
    }
    db operator * (db a,ll b){
        db res;
        ll jw=0;
        for(int i=K;i>=1;--i)res.xs[i]=a.xs[i]*b+jw,jw=res.xs[i]/10,res.xs[i]%=10;
        res.zs=a.zs*b+jw;
        return res;
    }
    void printdb(db a,int k=15){
        if(a.xs[k+1]>=5)a.xs[k]++;
        int t=k;
        while(a.xs[t]>=10){
            a.xs[t]-=10;
            if(t!=1)a.xs[--t]++;
            else{a.zs++;break;}
        }
        cout<<a.zs<<".";
        for(int i=1;i<=k;++i)cout<<a.xs[i];
    }
}
using namespace Bigdouble;
const int MAXN=55;
ll dp[MAXN][MAXN*MAXN][MAXN][2];
db ans[MAXN];
int main() {
    dp[0][0][0][0]=1;
    for(int i=0;i<50;++i){
        for(int j=0;j<=i*(i+1)/2;++j){
            for(int k=0;k<=i;++k){
                for(int t=0;t<=1;++t){
                    dp[i+1][j][0][t^1]+=dp[i][j][k][t];
                    dp[i+1][j+k+1][k+1][t]+=dp[i][j][k][t];
                }
            }
        }
    }
    for(int n=1;n<=50;++n){
        for(int j=0;j<n*(n+1)/2;++j){
            db tmp=makedb(n*(n+1)/2,n*(n+1)/2-j);
            ll sum=0;
            for(int k=0;k<=n;++k)sum+=dp[n][j][k][0];
            ans[n]=ans[n]-(tmp*sum);
            sum=0;
            for(int k=0;k<=n;++k)sum+=dp[n][j][k][1];
            ans[n]=ans[n]+(tmp*sum);
        }
    }
    //for(int n=1;n<=50;++n)printf("%d\n",n),printdb(ans[n]),puts("");return 0;
    int t=read();while(t--){
        int n=read();
        printdb(ans[n]);puts("");
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/dysyn1314/p/12371810.html
今日推荐