BZOJ4820 [SDOI2017]硬币游戏

BZOJ4820 [SDOI2017]硬币游戏

题面:BZOJ

解析

考虑把所有没有到达结束状态的字符串看做一类字符串\(N\),把以字符串\(i\)作为结束的一类字符串\(i\)。现在假设猜测了两个字符串\(A=TTH\)\(B=HTT\),不难发现可以列一个方程出来:

\[P(NTTH)=P(A)\]

遗憾的是,这个方程是错的,因为在\(N\)向后匹配的过程中,有可能先一步匹配其他字符串,比如\(N\)\(HT\)结尾,只要加上一个\(T\),就以\(B\)结束了,我们试着把它写进方程里面,有:

\[P(NTTH)=P(A)+P(BTH)+P(BH)\]

写作带系数的形式:

\[0.125P(N)=P(A)+0.25P(B)+0.5P(B)\]

不难发现只要一个字符串\(B\)的后缀为字符串\(A\)的前缀,在\(A\)对应的方程里就有一个贡献,这个不难,随便用个字符串算法就能处理。

那么我们可以对\(n\)个字符串分别列出方程,再加上所有\(P(i)\)概率之和为1,一共\(n+1\)个方程,\(n+1\)个变量,高斯消元即可。

代码


#include<cmath>
#include<queue>
#include<cstdio>
#define N 305

using namespace std;

int n,m,s[N][N]; char str[N];

inline int In(){
    char c=getchar(); int x=0,ft=1;
    for(;c<'0'||c>'9';c=getchar()) if(c=='-') ft=-1;
    for(;c>='0'&&c<='9';c=getchar()) x=x*10+c-'0';
    return x*ft;
}

int h[N*N],e_tot=0;
struct E{ int to,nex; }e[N*N];

inline void add(int u,int v){
    e[++e_tot]=(E){v,h[u]}; h[u]=e_tot;
}

int rt=0,T_tot=0,ch[N*N][2],fail[N*N],d[N*N];

inline void Insert(int i){
    int u=rt;
    for(int j=0;j<m;++j){
        if(!ch[u][s[i][j]]) ch[u][s[i][j]]=++T_tot;
        u=ch[u][s[i][j]]; add(u,i);
    }
}

inline void Get_fail(int m){
    queue<int> Q; fail[rt]=rt;
    for(int i=0,v;i<=m;++i){
        v=ch[rt][i]; if(!v) continue;
        Q.push(v); fail[v]=rt; d[v]=1;
    }
    while(!Q.empty()){
        int u=Q.front(); Q.pop();
        for(int i=0,p,v;i<=m;++i){
            v=ch[u][i]; if(!v) continue;
            Q.push(v); p=fail[u];
            while(p&&!ch[p][i]) p=fail[p];
            fail[v]=ch[p][i]; d[v]=d[u]+1;
        }
    }
}

double a[N][N],p[N],ans[N];

inline void Get_v(int u,int B){
    if(!u) return;
    for(int i=h[u];i;i=e[i].nex) a[e[i].to][B]-=p[m-d[u]];
    Get_v(fail[u],B);
}

inline void Get_a(){
    for(int i=1;i<=n;++i) a[i][n+1]=p[m];
    for(int i=1,u;i<=n;++i){
        u=rt;
        for(int j=0;j<m;++j){
            while(u&&!ch[u][s[i][j]]) u=fail[u];
            u=ch[u][s[i][j]];
        }
        Get_v(u,i);
    }
    for(int i=1;i<=n;++i) a[n+1][i]=1.0; a[n+1][n+2]=1.0;
}

inline void Gauss(int n){
    for(int i=1,q;i<=n;++i){
        q=i;
        for(int j=i+1;j<=n;++j)
        if(fabs(a[j][i])>fabs(a[q][i])) q=j;
        swap(a[i],a[q]);
        for(int j=n+1;j>=i;--j)
        a[i][j]=a[i][j]/a[i][i];
        for(int j=i+1;j<=n;++j)
        for(int k=n+1;k>=i;--k)
        a[j][k]-=a[i][k]*a[j][i];
    }
    for(int i=n;i;--i){
        ans[i]=a[i][n+1];
        for(int j=i+1;j<=n;++j)
        ans[i]-=a[i][j]*ans[j];
    }
}

int main(){
    n=In(); m=In(); p[0]=1.0;
    for(int i=1;i<=m;++i) p[i]=0.5*p[i-1];
    for(int i=1;i<=n;++i){
        scanf("%s",str);
        for(int j=0;j<m;++j) s[i][j]=(str[j]=='H');
        Insert(i);
    }
    Get_fail(1); Get_a(); Gauss(n+1);
    for(int i=1;i<=n;++i) printf("%.10lf\n",ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/pkh68/p/10643492.html