[ZJOI2012]波浪

Description:

L = | P2 – P1 | + | P3 – P2 | + … + | PN – PN-1 |
给你一个N和M,问:随机一个1…N的排列,它的波动强度(L)不小于M的概率有多大?

Hint:

\(n \le 100\)

Solution:

传说中的神仙dp,难在如何转化问题

绝对值很不好搞,我们考虑按从小到大顺序依次插入这些数

\(f[i][j][k][l]\)表示插入第i个数后,分成了j个连续段,强度为k,有l个边界已确定的方案数

然后就很套路了,详见代码

这种dp真的需要一些大胆的想法,要敢想敢打

#include <map>
#include <set>
#include <stack>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#define ls p<<1 
#define rs p<<1|1
using namespace std;
typedef long long ll;

inline int read() {
    char c=getchar(); int x=0,f=1;
    while(c>'9'||c<'0') {if(c=='-') f=-1;c=getchar();}
    while(c<='9'&&c>='0') {x=(x<<3)+(x<<1)+(c&15);c=getchar();}
    return x*f;
}
inline int chkmax(int &x,int y) {if(x<y) x=y;}
inline int chkmin(int &x,int y) {if(x>y) x=y;}

namespace db {long double dp[2][110][10010][3];}
namespace flt {__float128 dp[2][110][10010][3];}
int n,m,k;
template <class T>

void print(T ans) {
    cout<<"0.";
    ans*=10;
    for(int i=1;i<=k;++i) {
        cout<<(int ) (ans+(k==i)*0.5);
        ans=(ans-(int )ans)*10;
    }
}

template <class T>

void solve(T dp[][110][10010][3]) {
    T ans=0; int t=0;
    dp[0][0][5000][0]=1; //因为状态可能出现负数,故预处理出值域,把0当成5000
    for(int i=1;i<=n;++i) {
        t^=1; memset(dp[t],0,sizeof(dp[t]));
        for(int j=0;j<=min(i-1,m);++j) {
            for(int k=0;k<=10000;++k) {
                for(int l=0;l<=2;++l) {
                    if(!dp[t^1][j][k][l]) continue ;
                    if(k-2*i>=0) dp[t][j+1][k-2*i][l]+=dp[t^1][j][k][l]*(j+1-l);
                    if(j) dp[t][j][k][l]+=dp[t^1][j][k][l]*(j*2-l);
                    if(j>=2&&k+i*2<=10000)
                        dp[t][j-1][k+2*i][l]+=dp[t^1][j][k][l]*(j-1);
                    if(l<2) {
                        if(k-i>=0) dp[t][j+1][k-i][l+1]+=dp[t^1][j][k][l]*(2-l);
                        if(j&&k+i<=10000) 
                            dp[t][j][k+i][l+1]+=dp[t^1][j][k][l]*(2-l);
                    }
                } 
            }
        }
    }
    for(int i=m;i<=5000;++i) 
        ans+=dp[t][1][5000+i][2];
    for(int i=1;i<=n;++i) 
        ans/=i;
    print(ans); 
}

int main()
{
    n=read(); m=read(); k=read();
    if(k<=8) solve(db::dp);
    else solve(flt::dp);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/list1/p/10505479.html