[学习笔记] KthMax-Min - Min-Max容斥

有n个数字,每单位时间会出现一个数字,第i个数字有 p i m 的概率出现,并且 i = 1 n p i = m ,求出现了k个数字的时间的期望。
n 1000 , m 10000 , n k 10
这个玄学数据范围可海星,可以做到O(nm(n-k))或者O(nmk)。
首先考虑这个问题,询问的k,等价于询问集合的第n-k+1大。
现在考虑转化后的问题,也就是要找到一个函数F,使得下式成立:
T S m i n ( T ) F ( | T | ) = k ,这里的k已经是刚刚的n-k+1了。
min(T)在这个题目里面就是m/Sum(T)。
考虑一般Min-Max容斥的证明,对于第x大的元素,其被计算系数应该为[x==k]:
[ x = k ] = T = 0 n x ( n x T ) F ( | T | + 1 )
我们希望构造一个函数使得这个式子成立,也就是当x< k的时候不会被计算,当x>k的时候会被容斥掉,那么我们取:
F ( x ) = ( 1 ) x k ( x 1 k 1 )
带回原式:
[ x = k ] = T = 0 n x ( n x T ) ( 1 ) T + 1 k ( T k 1 )
这样,当n-x< k-1时,T< k-1,后面那个组合数直接是0;否则,当n-x>=k的时候,相当于是先从n-x个元素中选出T的,然后从其中选出k-1个;换句话说就是先从n-x个中选出k-1个,然后剩下的n-x-k+1个元素再任意选择,并且根据这些元素的个数决定系数的正负,显然后半部分只有当n-x-k+1=0的时候系数是1,此时两个组合数也是1,这样就得证了。
我们带回最最开始的式子:
T S m S u m ( T ) ( 1 ) | T | k ( | T | 1 k 1 )
注意到k比较小,我们设dp(i, j, t)表示前i个数字,选出来的数字之和是t,前面的系数之和在k=j的时候是多少,转移显然。
如果k比较大,那么因为|T|>=k所以直接设dp(i, j, t)表示前i个数字删去j个,剩下的数字之和是多少。
总之复杂度如上文。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<assert.h>
#define N 1010
#define M 10010
#define NMK 15
#define mod 998244353
#define lint long long
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
int dp[2][NMK][M],p[N],fac[M],facinv[M],inv[M];
inline int sol(lint x,int s) { return x%=mod,((s&1)?(mod-x)%mod:x); }
inline int fast_pow(int x,int k,int ans=1)
{   for(;k;k>>=1,x=(lint)x*x%mod) (k&1)?ans=(lint)ans*x%mod:0;return ans;   }
inline int prelude(int n)
{
    for(int i=fac[0]=1;i<=n;i++) fac[i]=(lint)fac[i-1]*i%mod;
    facinv[n]=fast_pow(fac[n],mod-2);
    for(int i=n-1;i>=0;i--) facinv[i]=facinv[i+1]*(i+1ll)%mod;
    for(int i=1;i<=n;i++) inv[i]=(lint)fac[i-1]*facinv[i]%mod;
    return 0;
}
int main()
{
    int n,k,m,ans=0,now,pre;scanf("%d%d%d",&n,&k,&m);
    rep(i,1,n) scanf("%d",&p[i]);prelude(max(n,m));
    pre=0,now=1,k=n-k+1;rep(i,1,k) dp[pre][i][0]=-1;
    for(int i=1;i<=n;i++,swap(now,pre))
    {
        rep(j,1,k) memcpy(dp[now][j],dp[pre][j],sizeof(int)*(m+1));
        rep(j,1,k) rep(t,p[i],m)
            (dp[now][j][t]+=mod-dp[pre][j][t-p[i]])%=mod,
            (dp[now][j][t]+=dp[pre][j-1][t-p[i]])%=mod;
    }
    swap(now,pre);
    rep(i,1,m) if(dp[now][k][i]) (ans+=(lint)dp[now][k][i]*inv[i]%mod)%=mod;
    return !printf("%lld\n",(lint)ans*m%mod);
}

猜你喜欢

转载自blog.csdn.net/Mys_C_K/article/details/81284563