LOJ #2552. 「CTSC2018」假面

水平急剧下降.jpg。SB题调了一个小时233

首先显然我们为了回答最后的问题,需要一个\(p_{i,j}\)表示第\(i\)个人剩下\(j\)滴血的概率

然后我们在做结界时就可以确认一个的人死/没死的概率

考虑设某个人\(x\)活着的概率是\(a_x\),死亡的概率是\(d_x\),显然对于\(x\)的击中概率:

\[P(x)=a_x\times\sum_{i=0}^{k-1} \frac{g_{x,i}}{i+1}\]

其中\(g_{x,i}\)表示除去\(x\)后剩下\(i\)个人的概率,考虑我们搞一个\(f_{i,j}\)表示前\(i\)个人剩下\(j\)个的概率,显然:

\[f_{i,j}=f_{i-1,j}\times d_i+f_{i-1,j-1}\times a_x\]

如果我们每次隔离掉一个数在求出\(f_{k-1}\)那显然是可以的,但复杂度是\(O(n^3)\)

仔细一想如果我们先把所有的做到一起,然后消去\(x\)的影响,逆推一下就可以退出\(g_x\),复杂度变为\(O(n^2)\),足以通过此题

#include<cstdio>
#include<cstring>
#define RI register int
#define CI const int&
using namespace std;
const int N=205,mod=998244353;
int n,w[N],q,opt,x,y,m,id[N],p[N][N],f[N],g[N],inv[N],ans[N];
inline int sum(CI x,CI y)
{
    int t=x+y; return t>=mod?t-mod:t;
}
inline int sub(CI x,CI y)
{
    int t=x-y; return t<0?t+mod:t;
}
inline int quick_pow(int x,int p=mod-2,int mul=1)
{
    for (;p;p>>=1,x=1LL*x*x%mod) if (p&1) mul=1LL*mul*x%mod; return mul;
}
inline void attack(CI pos,CI pr)
{
    int ms=sub(1,pr); for (RI i=0;i<=w[pos];++i)
    {
        if (i) p[pos][i]=1LL*p[pos][i]*ms%mod;
        if (i!=w[pos]) p[pos][i]=sum(p[pos][i],1LL*p[pos][i+1]*pr%mod);
    }
}
inline void solve(void)
{
    RI i,j; for (memset(f,0,sizeof(f)),f[0]=i=1;i<=m;++i)
    {
        int d=p[id[i]][0],a=sub(1,d); for (j=m;~j;--j)
        f[j]=sum(1LL*f[j]*d%mod,j?1LL*f[j-1]*a%mod:0);
    }
    for (i=1;i<=m;++i)
    {
        int d=p[id[i]][0],a=sub(1,d); ans[i]=0; if (!d)
        for (j=0;j<m;++j) ans[i]=sum(ans[i],1LL*f[j+1]*inv[j+1]%mod); else
        {
            int ivd=quick_pow(d); for (j=0;j<m;++j)
            g[j]=1LL*sub(f[j],j?1LL*g[j-1]*a%mod:0)*ivd%mod,
            ans[i]=sum(ans[i],1LL*g[j]*inv[j+1]%mod);
        }
        ans[i]=1LL*ans[i]*a%mod;
    }
    for (i=1;i<=m;++i) printf("%d%c",ans[i]," \n"[i==m]);
}
int main()
{
    RI i,j; for (scanf("%d",&n),i=1;i<=n;++i) scanf("%d",&w[i]),p[i][w[i]]=1,inv[i]=quick_pow(i);
    for (scanf("%d",&q);q;--q)
    {
        scanf("%d",&opt); if (!opt) scanf("%d%d%d",&m,&x,&y),attack(m,1LL*x*quick_pow(y)%mod);
        else { for (scanf("%d",&m),i=1;i<=m;++i) scanf("%d",&id[i]); solve(); }
    }
    for (i=1;i<=n;++i)
    {
        int ret=0; for (j=1;j<=w[i];++j) ret=sum(ret,1LL*j*p[i][j]%mod); printf("%d%c",ret," \n"[i==n]);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/cjjsb/p/12222162.html