2020杭电多校第五场 Set2(DP,组合数学)

Problem Description
You are given a set S={1…n}. You have to do the following operations until there are no more than k elements left in the S:

Firstly, delete the smallest element of S, and then randomly delete another k elements one by one from the elements left in S in equal probability.

Note that the order of another deleted k elements matters. That is to say, you delete p after q or delete q after p, which are different ways.

For each i∈[1,n], determine the probability of i being left in the S.

It can be shown that the answers can be represented by PQ, where P and Q are coprime integers, and print the value of P×Q−1 mod 998244353.

Input
The first line contains the only integer T(T∈[1,40]) denoting the number of test cases.

For each test case:

The first line contains two integers n and k.

It guarantees that: n∈[1,5000], ∑n∈[1,3×104], k∈[1,5000].

Output
For each test case, you should output n integers, the i-th of them means the probability of i being left in the S.

Sample Input
1
5 2

Sample Output
0 499122177 499122177 499122177 499122177

Source
2020 Multi-University Training Contest 5

题意:
每次删除一个最小的数,再随机删除k个数,求每个数剩下的概率。

思路:
大概思路就是将随机删除k个数合在一起,则可以定义 d p [ i ] [ j ] dp[i][j] ,删除了前i个数,还可以随机删除 j j 次数的操作次数。

那么 d p [ i + 1 ] [ j + k ] + = d p [ i ] [ j ] dp[i+1][j+k]+=dp[i][j] ,代表删除了最小的数。
d p [ i + 1 ] [ j 1 ] + = d p [ i ] [ j ] dp[i+1][j-1]+=dp[i][j] ,代表随机删数删掉了 i i ,因为有 j j 次操作,每次操作都可以删 i i ,所以要乘以 i i

得到操作次数,再算剩下的数的删除次序,就知道每个数剩下的方案数了。

下面是题解
在这里插入图片描述
注意题解中有笔误,应该是 j = n i + 1 r j=n-i+1-r

这题是set1的升级版,set1删完了最小的数,每次只随机删除一个数字,这题随机删除k个。所以本题不能直接枚举剩下的数再组合数学。

感觉题解就十分巧妙,随机删 k k 个太难考虑了,我就合并变成还剩下 j j 个随机删除操作。然后通过DP算出删除了前 i i 个操作方案数,这样剩下的部分就成了有x个数,要删y个数,那就是组合数学问题了。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#include <queue>
#include <iostream>
#include <map>
#include <string>

using namespace std;

typedef long long ll;

const int mod = 998244353;
const int maxn = 5000 + 7;

ll dp[maxn][maxn],f[maxn],cnt[maxn],ans[maxn],sum[maxn];
ll fac[maxn],inv[maxn];

ll qpow(ll x,ll n) {
    ll res = 1;
    while(n) {
        if(n & 1) res = res * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return res;
}

ll C(ll n,ll m) {
    if(m > n || m < 0) return 0;
    return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

void init() {
    fac[0] = inv[0] = 1;
    for(int i = 1;i < maxn;i++) {
        fac[i] = fac[i - 1] * i % mod;
        inv[i] = qpow(fac[i],mod - 2);
    }
}

int main() {
    int T;scanf("%d",&T);
    init();
    while(T--) {
        int n,k;scanf("%d%d",&n,&k);
        if(n < k + 1) {
            for(int i = 1;i <= n;i++) {
                printf("1");
                if(i != n) printf(" ");
            }
            printf("\n");
            continue;
        }
        if(n % (k + 1) == 0) {
            for(int i = 1;i <= n;i++) {
                printf("0");
                if(i != n) printf(" ");
            }
            printf("\n");
            continue;
        }
        for(int i = 0;i <= n + 1;i++) {
            for(int j = 0;j <= n + 1;j++) {
                dp[i][j] = 0;
            }
            f[i] = sum[i] = cnt[i] = ans[i] = 0;
        }
        dp[0][0] = 1;
        for(int i = 0;i <= n;i++) {
            for(int j = 0;j <= n;j++) {
                if(j + k <= n - i - 1) {
                    dp[i + 1][j + k] = (dp[i + 1][j + k] + dp[i][j]) % mod;
                }
                if(j > 0) {
                    dp[i + 1][j - 1] = (dp[i + 1][j - 1] + dp[i][j] * j % mod) % mod;
                }
            }
        }
        
        ll SUM = 0;//总方案
        int r = n % (k + 1);//最后剩下的数字
        for(int i = 1;i <= n;i++) { //前i个数都被删掉了
            int j = n - i - r + 1; //可以进行的删除次数
            f[i] += dp[i - 1][j] * fac[j] % mod * C(n - i,j) % mod;
            f[i] %= mod;
            
            cnt[i] += dp[i - 1][j] * fac[j] % mod * C(n - i - 1,j) % mod;
            cnt[i] %= mod;
            
            sum[i] = (sum[i - 1] + cnt[i]) % mod;
            SUM = (SUM + f[i]) % mod;
            ans[i] = (f[i] + sum[i - 1]) % mod;
        }
        printf("FUCK %lld\n",dp[2][1]);
        for(int i = 1;i <= n;i++) {
            printf("%lld ",f[i]);
            if(i == n) printf("\n");
        }
        SUM = qpow(SUM,mod - 2);
        for(int i = 1;i <= n;i++) {
            printf("%lld",ans[i] * SUM % mod);
            if(i != n) printf(" ");
            else printf("\n");
        }

    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/tomjobs/article/details/107822842