#loj3089 [BJOI2019]奥术神杖

卡精度好题

最关键的一步是几何平均数的\(ln\)等于所有数字取\(ln\)后的算术平均值

那么现在就变成了一个很裸的01分数规划问题,一个通用的思路就是二分答案

现在来考虑二分答案的底层怎么写

把所有串拉出来造ac自动机,那么ac自动机上一个点的权值就是

fail树上这个点到祖先的树链上的字符串的权值之和

那么接下来设\(f(i,j)\)表示决策到了第\(i\)个字符,走到自动机节点\(j\)的最大收益大力dp即可

由于我们不希望均值是0,因此额外记录下有没有匹配上模式串即可

二分完了之后把mid调成l重新跑一遍dp,不然你会跑出无解的情况导致输出错误的答案

有个函数叫log2,精度比log高,这样你就可以少二分几次了

// luogu-judger-enable-o2
// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<queue>
using namespace std;const int N=4000;typedef long long ll;
typedef long double db;db mid;//const db eps=1e-8;
const db low_inf=-1e9;
int v[N<<1];int x[N<<1];int ct;int al[N];
db we[N];db sum[N];int num[N];int snum[N];
inline void add(int u,int V){
//printf("add %d %d\n",u,V);
v[++ct]=V;x[ct]=al[u];al[u]=ct;}
inline void pdfs(int u)
{
    snum[u]+=num[u];
    for(int i=al[u];i;i=x[i])
        snum[v[i]]=snum[u],pdfs(v[i]);
}
inline void dfs(int u)
{
    sum[u]+=we[u]-mid*num[u];
    for(int i=al[u];i;i=x[i])
        sum[v[i]]=sum[u],dfs(v[i]);
}
struct trie
{
    int mp[N][12];int cnt;int fil[N];
    inline int ins(int p,int c)
    {return mp[p][c]=(mp[p][c])?mp[p][c]:++cnt;}
    inline void build()
    {
        queue <int> q;
        for(int i=1;i<=10;i++)
            if(mp[1][i])fil[mp[1][i]]=1,q.push(mp[1][i]);
            else mp[1][i]=1;
        while(!q.empty())
        {
            int nw=q.front();q.pop();
            for(int i=1;i<=10;i++)
                if(mp[nw][i])fil[mp[nw][i]]=mp[fil[nw]][i],q.push(mp[nw][i]);
                else mp[nw][i]=mp[fil[nw]][i];
        }
        for(int i=1;i<=cnt;i++)
            if(fil[i])add(fil[i],i);
    }
}tr;
struct data{int c;int lst;int pval;}fr[N][N];db dp[N][N];
char mde[N];int n;int m;char smde[N];int op[N];int hd;
inline void trans(int i,int j,int k)
{
    int tw=tr.mp[j][k];
    db tval=dp[i][j]+sum[tw];
    if(tval>=dp[i+1][tw])
    {
        dp[i+1][tw]=tval;
        fr[i+1][tw]=(data){k,j,fr[i][j].pval+snum[tw]};
    }
}
inline void pritans()
{
    db curmx=-0x3f3f3f3f;int st=-1;
    for(int i=1;i<=tr.cnt;i++)
        if(curmx<dp[n][i])
            curmx=dp[n][i],st=i;
    hd=0;
    for(int i=n;i>=1;i--)
        op[++hd]=fr[i][st].c,st=fr[i][st].lst;
    for(int i=n;i>=1;i--)
        printf("%d",op[i]-1);
    printf("\n");
}
inline bool jud()
{
    //for(int i=1;i<=tr.cnt;i++)
    //  printf("%.3Lf ",sum[i]);printf("\n");
    sum[1]=0;
    dfs(1);
    for(int i=0;i<=n;i++)
        for(int j=1;j<=tr.cnt;j++)
            dp[i][j]=-0x3f3f3f3f;
    dp[0][1]=0;
    for(int i=0;i<n;i++)
        for(int j=1;j<=tr.cnt;j++)
        {
            if(dp[i][j]<low_inf)continue;
            if(mde[i+1]=='.')
                for(int k=1;k<=10;k++)
                    trans(i,j,k);
            else 
                trans(i,j,mde[i+1]-'0'+1);
        }
    db mx=-0x3f3f3f3f;
    for(int i=1;i<=tr.cnt;i++)
        if(fr[n][i].pval)mx=max(mx,dp[n][i]);
    //printf("mx=%.10Lf\n",mx);
   // pritans();
    return mx>=0;
}

int main()
{
    //printf("%.10lf\n",exp(log(10)));
    scanf("%d%d",&n,&m);
    scanf("%s",mde+1);
    tr.cnt=1;
    for(int i=1,tmp;i<=m;i++)
    {
        scanf("%s",smde+1);
        int p=1;
        for(int j=1;smde[j]!='\0';j++)
            p=tr.ins(p,smde[j]-'0'+1);
        scanf("%d",&tmp);
        we[p]+=log2(tmp);num[p]++;
      //  printf("%.10Lf\n",we[p]);
    }
    tr.build();
    pdfs(1);
    db l=0;db r=log2(1e9);
    for(int i=1;i<=18;i++)
    {
    //  printf("%.10Lf %.10Lf\n",l,r);
        mid=(l+r)/2;
        if(jud())l=mid;else r=mid;
    }
    mid=l;
    jud();
    pritans();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/sweetphoenix/p/10786229.html