CF739E Gosha is hunting 【WQS二分 + 期望】

题目链接

CF739E

题解

抓住个数的期望即为概率之和
使用\(A\)的期望为\(p[i]\)
使用\(B\)的期望为\(u[i]\)
都使用的期望为\(p[i] + u[i] - u[i]p[i]\)
当然是用越多越好

但是他很烦地给了个上限,我们就需要作出选择了
有一个很明显的\(O(n^3)\)\(dp\),显然过不了

但我们有一个很好的\(WQS\)二分
我们非常想去掉这个上限
那就去掉吧,但是每用一次都要付出一个代价
我们二分这个代价,当使用次数恰好为为\(a\)\(b\)时就是答案
再加回付出的代价即可
非常巧妙地变成了\(O(n\log^2n)\)

这种二分技巧非常棒
当我们求的东西有一个限制个数时,可以通过设置代价去掉上限

//Mychael
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<map>
#define LL long long int
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define cls(s,v) memset(s,v,sizeof(s))
#define mp(a,b) make_pair<int,int>(a,b)
#define cp pair<int,int>
#define eps 1e-9
using namespace std;
const int maxn = 2005,maxm = 100005,INF = 0x3f3f3f3f;
inline int read(){
    int out = 0,flag = 1; char c = getchar();
    while (c < 48 || c > 57){if (c == '-') flag = 0; c = getchar();}
    while (c >= 48 && c <= 57){out = (out << 1) + (out << 3) + c - 48; c = getchar();}
    return flag ? out : -out;
}
int n,a,b,cnta,cntb;
double p[maxn],u[maxn],A,B,ans;
int work(double cost){
    A = cost; cnta = cntb = 0; ans = 0;
    int sol; double val;
    REP(i,n){
        val = 0; sol = 0;
        if (p[i] - A > val) sol = 1,val = p[i] - A;
        if (u[i] - B > val) sol = 2,val = u[i] - B;
        if (p[i] + u[i] - u[i] * p[i] - A - B > val)
            sol = 3,val = p[i] + u[i] - u[i] * p[i] - A - B;
        if (sol == 1 || sol == 3) cnta++;
        if (sol == 2 || sol == 3) cntb++;
        ans += val;
    }
    return cnta;
}
int check(double cost){
    B = cost;
    double l = 0,r = 1.0,mid;
    while (r - l > eps){
        mid = (l + r) / 2.0;
        if (work(mid) <= a) r = mid;
        else l = mid;
    }
    work(r);
    A = l;
    return cntb;
}
int main(){
    n = read(); a = read(); b = read();
    REP(i,n) scanf("%lf",&p[i]);
    REP(i,n) scanf("%lf",&u[i]);
    double l = 0,r = 1.0,mid;
    while (r - l > eps){
        mid = (l + r) / 2.0;
        if (check(mid) <= b) r = mid;
        else l = mid;
    }
    check(r);
    printf("%.8lf",ans + a * A + b * B);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Mychael/p/9264916.html