CodeForces - Educational Round 64 F.Card Bag(乘法逆元+概率dp)

链接:

        Card Bag

题意:

        给n张牌,每张牌有个值ai,每次等概率从没选过的牌中选一张,当摸到连续两张相同牌时你赢,当摸到的牌比上一张牌小时,你输,当摸到的牌比上一张牌大时,继续游戏。没有牌之后你输。问赢的概率。输出概率对998244353取模。

思路:

        先考虑下输出,一个分数对998244353取模,肯定是逆元了,这个很容易想到,下文所有除法在代码中实现均为逆元。

        既然要最后输出一个概率,肯定是概率相关的算法。那到底是推组合数公式还是dp,要看下面分析了。

        (以上两行废话)

        考虑什么时候下一次可以立即赢?要满足第i次和第i+1次摸到的牌相等,还要保证前i-1次摸到的牌是单调增的,且每张都不同(如果相同已经赢了)。

        所以我们可以记录"当前摸到的最大的牌是几",然后下一次摸牌只能从>=它的牌中选,如果相等,赢,游戏结束,否则继续。

        考虑到n只有5000,空间复杂度允许O(n^2),可以开一个dp[i][j]代表第i轮摸完,最后一张牌(最大的)是j时的概率,则下一次可以立即赢的概率为:

        dp[i][j]\cdot\frac{cnt[j]-1}{n-i} 更新ans。

        dp的状态转移方程也很好想:

        dp[i+1][j]=\sum_{k=1}^{j-1} dp[i][k]\cdot \frac{cnt[j]}{n-i}

        含义就是,第i+1次摸到j的概率为所有摸了i次,最大数为k(k<j)的概率*第i+1次摸到j的概率。

        此题已经可写。但是,这题还能进一步优化,首先可以用前缀和来简化dp[i][j],一旦将dp[i][j]计算完成产生的ans之后,就dp[i][j] += dp[i][j-1],方便在下一轮直接O(1)得到sigma{dp[i][k]}。

        其次,这题还能优化空间复杂度。我们发现,状态转移只和相邻两个回合有关,所以可以优化为i&1的滚动数组dp[2][5010],然后,发现二维的滚动数组也没必要,其实可以优化为一维dp数组,只要在计算dp[j]时dp[j-1]是上一轮概率的前缀和即可,因此每轮计算概率要按下标递减,然后再正序更新概率前缀和。要注意剪枝掉无法访问的点,否则概率会出错。

代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
static const int maxn = 5050;
static const int mod = 998244353;
ll inv[maxn],sum[maxn];
ll quickmod(ll a,ll b){
    ll ret = 1;
    while(b){
        if(b&1)ret *= a,ret %= mod;
        b >>= 1;
        a = a*a % mod;
    }
    return ret;
}
ll dp[maxn],cnt[maxn],ans;

void redirect(){
    #ifdef LOCAL
        freopen("test.txt","r",stdin);
    #endif
}

int main(){
    redirect();
    int n;
    scanf("%d",&n);
    inv[0] = 1;
    for(int i = 1;i <= n;i++)inv[i] = quickmod(i,mod-2);
    for(int i = 1;i <= n;i++){
        int x;
        scanf("%d",&x);
        cnt[x]++;
    }
    for(int i = 1;i <= n;i++)sum[i] = sum[i-1] + cnt[i];
    fill(dp,dp+maxn,1);
    for(int i = 1;i < n;i++){
        for(int j = n;j >= i;j--){
            if(i-1 > sum[j-1]){
                dp[j] = 0;
                continue;
            }
            dp[j] = ((dp[j-1] * cnt[j]) % mod * inv[n-i+1]) % mod;
            if(cnt[j] > 1)
            ans += ((dp[j] * (cnt[j] - 1)) % mod * inv[n-i]) % mod,ans %= mod;
            dp[j] %= mod;
        }
        for(int j = i+1;j <= n;j++)dp[j] += dp[j-1],dp[j] %= mod;
    }
    printf("%I64d\n",ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/krypton12138/article/details/90147276