2019雅礼集训 D7T1 inverse [概率/期望,DP]

题目描述:

样例:

input1:
3 1
1 2 3

output1:
833333340

input2:
5 10
2 4 1 3 5

output2:
62258360

数据范围与约定:

标签:概率/期望,DP


概率/期望的常用套路:将许许多多个元素单独考虑,以达到解决问题的目的。

这里发现不可能整个序列一起考虑,于是枚举任意两个位置,计算出k次翻转之后左边大于右边的概率,再加起来就好了。

于是我们有了一个非常暴力的DP:

\(dp(i,j,k)\) 表示k次翻转之后\(i\)位置大于\(j\)位置的概率。为了方便我们强行令\(i>j\)

然后,对于上一次发生的几种可能的区间翻转分别考虑:

一、\(r<i\; ||\; l>j\; ||\; i<l \leq r<j\)

我们发现这种情况时\(i\)\(j\)没有受到影响,于是直接继承上一个时间的状态即可。

贡献:\(({{i-1}\choose 2}+{{n-j}\choose 2}+{{j-i-1}\choose 2})dp(i,j,k-1)\)

二、\(l \leq i \leq r < j\)

此时\(i\)在上一次的位置是\(l+r-i\)\(j\)没有发生变化。

贡献:\(\sum_{l=1}^i \sum_{r=i}^{j-1} dp(l+r-i,j,k-1)\)

三、\(i<l \leq j \leq r​\)

此时\(j\)在上一次的位置是\(l+r-j\)\(i\)没有变化。

贡献:\(\sum_{l=i+1}^j \sum_{r=j}^n dp(i,l+r-j,k-1)​\)

四、\(l \leq i < j \leq r\)

此时\(i\)\(j\)的位置都发生了变化,上一次分别是\(l+r-i\)\(l+r-j\)

贡献:

\(\sum_{l=1}^i \sum_{r=j}^n (1-dp(l+r-j,l+r-i,k-1))=i(n-j+1)-\sum_{l=1}^i \sum_{r=j}^n dp(l+r-j,l+r-i,k-1)\)

将上面四种贡献加起来,再除个\(\frac{n(n+1)}{2}​\)就好了。

复杂度\(O(n^4k)\),显然要挂。

考虑优化:发现转移方程可以用前缀和优化一下,这里以第二种情况为例:
\[ \begin{align*} &令 S_1(n,j,k)=\sum_{i=1}^n dp(i,j,k),\; S_2(n,j,k)=\sum_{i=1}^n S_2(i,j,k) ,那么: \\ &\sum_{l=1}^i \sum_{r=i}^{j-1} dp(l+r-i,j,k-1) \\ =&\sum_{l=1}^i(S_1(l+j-i-1)-S_1(l-1,j,k-1))\\ =&S_2(j-1,j,k-1)-S_2(j-i-1,j,k-1)-S_2(i-1,j,k-1) \end{align*} \]
可以\(O(1)​\)求啦!

同理可得,
\[ \begin{align*} &令S_3(i,n,k)=\sum_{j=1}^n dp(i,j,k),S_4(i,n,k)=\sum_{j=1}^n S_3(i,j,k)\\ &第三种情况:S_4(i,n,k-1)-S_4(i,n+i-j,k-1)-S_4(i,j-1,k-1)\\ &令g(i,j,k)=dp(i,i+j,k-1),S_5(n,j,k)=\sum_{i=1}^n g(i,j,k),S_6(n,j,k)=\sum_{i=1}^n S_5(i,j,k)\\ &第四种情况:\\ &i(n-j+1)-S_6(n+i-j,j-i,k-1)+S_6(n-j,j-i,k-1)+S_6(i-1,j-i,k-1) \end{align*} \]
大功告成!

最后,由于出题人卡常(1s),你可能需要一些常数优化才能过(滑稽)

代码(上方的\(S_2,S_4,S_6在代码中分别为S_1,S_2,S_3\)):

#include<cstdio>

namespace my_std{
    const int mod=1e9+7;
    #define rep(i,x,y) for (register int i=(x);i<=(y);++i)
    #define drep(i,x,y) for (register int i=(x);i>=(y);--i)
    #define sz 510
    typedef long long ll;
    template<typename T>
    inline void read(T& t)
    {
        t=0;char ch=getchar();
        while(ch>'9'||ch<'0') ch=getchar();
        while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
    }
    void file(){freopen("a.txt","r",stdin);}
}
using namespace my_std;

inline ll ksm(ll x,int y)
{
    ll ret=1;
    for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;
    return ret;
}

int n,K;
int a[sz];

ll dp[51][sz][sz];

ll s1[51][sz][sz],s2[51][sz][sz],s3[51][sz][sz];

ll C[sz];

inline void M(ll &x){if (x>=mod) x-=mod;}

int main()
{
    file();
    read(n),read(K);
    rep(i,1,n) read(a[i]);
    rep(i,1,n)
        rep(j,i+1,n)
            dp[0][i][j]=(a[i]>a[j]);
    rep(i,1,n) C[i]=C[i-1]+i;
    ll INV=ksm(C[n],mod-2);
    rep(t,0,K-1)
    {
        rep(j,1,n)
        {
            rep(i,1,j-1) s1[t][i][j]=s1[t][i-1][j]+dp[t][i][j];
            rep(i,1,j-1) M(s1[t][i][j]+=s1[t][i-1][j]);
        }
        rep(i,1,n)
        {
            rep(j,i+1,n) s2[t][i][j]=s2[t][i][j-1]+dp[t][i][j];
            rep(j,i+1,n) M(s2[t][i][j]+=s2[t][i][j-1]);
        }
        rep(j,0,n-1)
        {
            rep(i,1,n-j) s3[t][i][j]=s3[t][i-1][j]+dp[t][i][i+j];
            rep(i,1,n-j) M(s3[t][i][j]+=s3[t][i-1][j]);
        }
        
        rep(i,1,n) rep(j,i+1,n)
        {
            ll &S=dp[t+1][i][j];
            
            S=(C[i-1]+C[n-j]+C[j-i-1])*dp[t][i][j];
            
            S+=s1[t][j-1][j]-s1[t][j-i-1][j]-s1[t][i-1][j];
            
            S+=s2[t][i][n]-s2[t][i][n+i-j]-s2[t][i][j-1]+s2[t][i][i-1];
            
            S-=s3[t][n+i-j][j-i]-s3[t][n-j][j-i]-s3[t][i-1][j-i];
            
            S+=i*(n-j+1);
            
            S=(S%mod+mod)%mod*INV%mod;

//          dp[t+1][i][j]=(C(i-1)+C(n-j)+C(j-i-1))*dp[t][i][j];
//          rep(l,1,i) rep(r,i,j-1) dp[t+1][i][j]+=dp[t][l+r-i][j];
//          rep(l,i+1,j) rep(r,j,n) dp[t+1][i][j]+=dp[t][i][l+r-j];
//          rep(l,1,i) rep(r,j,n) dp[t+1][i][j]+=1-dp[t][l+r-j][l+r-i];
//          dp[t+1][i][j]/=C[n];
//          此处为原DP方程(小数形式)
        }
    }
    
    ll ans=0;
    rep(i,1,n)
        rep(j,i+1,n)
            M(ans+=dp[K][i][j]);
            
    ans=(ans%mod+mod)%mod;
    
    printf("%lld",ans);
}

猜你喜欢

转载自www.cnblogs.com/p-b-p-b/p/10260907.html