【LOJ2541】猎人杀(PKUWC2018)-容斥+级数+分治NTT

测试地址:猎人杀
做法:本题需要用到容斥+级数+分治NTT。
要求 1 号最后一个被射杀,其实就是要求所有人都不能在 1 号后被射杀。这种要求全部条件满足求方案数/概率的情况,就要考虑容斥,即枚举一个集合 S ,计算强制这 S 个人在 1 号后被射杀的概率 p ( S ) ,那么答案就等于:
a n s = S ( 1 ) | S | p ( S )
可是由于游戏的每一步中,概率的分母都不同, p ( S ) 很难计算,怎么办呢?我们需要对游戏做出一些转化:一个人被射杀后,他仍然参与概率的计算,但如果射中了已经被射杀的人,就再射一次,显然这和原来的游戏是等价的。这样的话,令 s u m ( S ) = i S w i , W = i = 1 n w i ,我们有:
p ( S ) = i = 0 ( 1 w 1 + s u m ( S ) W ) i w 1 W
w 1 W 提出来后,剩下的和式是一个无穷级数,因为 0 < 1 w 1 + s u m ( S ) W < 1 ,所以这个级数是收敛的,那么它就等于前缀和数列的极限。我们有公式:
i = 0 x i = 1 1 x
所以有:
p ( S ) = w 1 W W w 1 + s u m ( S ) = w 1 w 1 + s u m ( S )
于是有:
a n s = w 1 S ( 1 ) | S | w 1 + s u m ( S )
虽然我们极大地简化了所求的式子,但是这个还是不太好求。这时我们注意到一个条件: i = 1 n w i 10 5 ,这启发我们分开计算每种分母的贡献。于是我们构造一个生成函数,其中 x i 项的系数就表示分母为 i 的数对答案贡献的分子,我们怎么算出这个生成函数呢?注意到这就等于 x w 1 i = 2 n ( x 0 x w i ) ,于是分治NTT求出后面的部分即可。这里的分治NTT就是单纯的分治+NTT,而不是CDQ分治+NTT。于是我们就解决了这一题,时间复杂度为 O ( W log W log n )
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const ll g=3;
int n,sum,rev[200010],cnt=0,siz[30];
ll w[200010],A[30][200010];

ll power(ll a,ll b)
{
    ll s=1,ss=a;
    b=(b+mod-1)%(mod-1);
    while(b)
    {
        if (b&1) s=s*ss%mod;
        ss=ss*ss%mod;b>>=1;
    }
    return s;
}

ll NTT(ll *a,int n,int type)
{
    for(int i=0;i<n;i++)
        if (i<rev[i]) swap(a[i],a[rev[i]]);
    for(int mid=1;mid<n;mid<<=1)
    {
        ll W=power(g,type*(mod-1)/(mid<<1));
        for(int l=0,G=(mid<<1);l<n;l+=G)
        {
            ll w=1;
            for(int k=0;k<mid;k++,w=w*W%mod)
            {
                ll x=a[l+k],y=w*a[l+mid+k]%mod;
                a[l+k]=(x+y)%mod;
                a[l+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if (type==-1)
    {
        ll inv=power(n,mod-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*inv%mod;
    }
}

void solve(int l,int r)
{
    if (l==r)
    {
        cnt++;
        A[cnt][0]=1,A[cnt][w[l]]=mod-1;
        siz[cnt]=w[l];
        for(int i=1;i<w[l];i++)
            A[cnt][i]=0;
        return;
    }

    int mid=(l+r)>>1;
    solve(l,mid);
    solve(mid+1,r);

    int bit=0,x=1,a=cnt-1,b=cnt,tot=siz[a]+siz[b];
    while(x<=tot) bit++,x<<=1;
    rev[0]=0;
    for(int i=1;i<x;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    for(int i=siz[a]+1;i<x;i++)
        A[a][i]=0;
    for(int i=siz[b]+1;i<x;i++)
        A[b][i]=0;
    NTT(A[a],x,1),NTT(A[b],x,1);
    for(int i=0;i<x;i++)
        A[a][i]=A[a][i]*A[b][i]%mod;
    NTT(A[a],x,-1);

    cnt--;
    siz[cnt]=tot;
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&w[i]);
        sum+=w[i];
    }

    if (n==1) printf("1");
    else
    {
        solve(2,n);
        ll ans=0;
        for(int i=0;i<=sum;i++)
            ans=(ans+A[1][i]*power(w[1]+i,mod-2))%mod;
        printf("%lld",ans*w[1]%mod); 
    }

    return 0; 
}

猜你喜欢

转载自blog.csdn.net/maxwei_wzj/article/details/80714129