LOJ2541 PKUWC2018 猎人杀 容斥、生成函数、分治

传送门


首先,每一次有一个猎人死亡之后\(\sum w\)会变化,计算起来很麻烦,所以考虑在某一个猎人死亡之后给其打上标记,仍然计算他的\(w\),只是如果打中了一个打上了标记的人就重新选择。这样对应于每一个人的概率仍然是一样的,而\(\sum w\)在计算的过程中不会变。

因为要求最后死的概率,似乎不是很好求,考虑容斥。枚举一个集合\(S\),我们强制集合\(S\)中的猎人在\(1\)号猎人死亡之后死亡。设集合\(S\)中所有猎人的\(w\)之和为\(A\),所有猎人的\(w\)之和为\(sum\),那么集合\(S\)能够产生的贡献为\((-1) ^ {|S|} \times \frac{w_1}{sum} \times \sum\limits_{i=0} ^ {\infty} (1 - \frac{A + w_1}{sum})^i\)

注意到后面是一个无穷递减等比数列,那么\(\sum\limits_{i=0} ^ {\infty} (1 - \frac{A + w_1}{sum})^i = \frac{1}{1 - (1 - \frac{A + w_1}{sum})} = \frac{sum}{A + w_1}\),那么原式等于\(-1^{|S|} \times \frac{w_1}{A + w_1}\)

那么我们只需要计算每一个集合的\(A\)就可以了。

注意到对于\(A\)的计算,实质是一个\(01\)背包。但是直接\(DP\)肯定复杂度爆炸,考虑生成函数求解

\(i\)个猎人的生成函数为\(-x^{w_i} + 1\)\(-x^{w_i}\)表示选择第\(i\)个猎人,但是集合的贡献乘上\(-1\)\(+1\)表示不选择第\(i\)个猎人。然后分治+\(NTT\)求解,我们就可以得到对于所有的\(A\)\(\frac{w_1}{A + w_1}\)前面的系数了。

总的复杂度为\(O(n\ log^2n)\)

#include<bits/stdc++.h>
#define ll long long
#define mid ((l + r) >> 1)
//This code is written by Itst
using namespace std;

inline int read(){
    int a = 0;
    char c = getchar();
    bool f = 0;
    while(!isdigit(c)){
        if(c == '-')
            f = 1;
        c = getchar();
    }
    while(isdigit(c)){
        a = (a << 3) + (a << 1) + (c ^ '0');
        c = getchar();
    }
    return f ? -a : a;
}

const int MOD = 998244353 , G = 3 , INV = 332748118 , MAXN = 2e5 + 10;
int val[MAXN] , dir[MAXN] , N , need , inv_need;
vector < int > v[MAXN];

inline int poww(ll a , int b){
    int times = 1;
    while(b){
        if(b & 1)
            times = times * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return times;
}

inline void NTT(vector < int > &arr , int type){
    while(arr.size() < need)
        arr.push_back(0);
    for(int i = 1 ; i < need ; ++i)
        if(i < dir[i])
            swap(arr[i] , arr[dir[i]]);
    for(int i = 1 ; i < need ; i <<= 1){
        int wn = poww(type == 1 ? G : INV , (MOD - 1) / (i << 1));
        for(int j = 0 ; j < need ; j += i << 1){
            ll w = 1;
            for(int k = 0 ; k < i ; ++k , w = w * wn % MOD){
                int x = arr[j + k] , y = arr[i + j + k] * w % MOD;
                arr[j + k] = x + y >= MOD ? x + y - MOD : x + y;
                arr[i + j + k] = x - y < 0 ? x - y + MOD : x - y;
            }
        }
    }
}

inline void solve(int l , int r){
    need = 1;
    while(need <= v[l].size() + v[r].size())
        need <<= 1;
    inv_need = poww(need , MOD - 2);
    for(int i = 1 ; i < need ; ++i)
        dir[i] = (dir[i >> 1] >> 1) | (i & 1 ? need >> 1 : 0);
    NTT(v[l] , 1);
    NTT(v[r] , 1);
    for(int i = 0 ; i < need ; ++i)
        v[l][i] = 1ll * v[l][i] * v[r][i] % MOD;
    NTT(v[l] , -1);
    for(int i = 0 ; i < need ; ++i)
        v[l][i] = 1ll * v[l][i] * inv_need % MOD;
    while(v[l][v[l].size() - 1] == 0)
        v[l].erase(--v[l].end());
}

int main(){
#ifndef ONLINE_JUDGE
    freopen("in" , "r" , stdin);
    //freopen("out" , "w" , stdout);
#endif
    N = read();
    if(N == 1){
        puts("1");
        return 0;
    }
    for(int i = 1 ; i <= N ; ++i){
        val[i] = read();
        if(i != 1){
            v[i].push_back(1);
            while(v[i].size() < val[i])
                v[i].push_back(0);
            v[i].push_back(MOD - 1);
        }
    }
    int ans = 0;
    for(int i = 1 ; i < N ; i <<= 1)
        for(int j = 2 ; j + i <= N ; j += i << 1){
            solve(j , j + i);
            vector < int >().swap(v[j + i]);
        }
    for(int i = 0 ; i < v[2].size() ; ++i)
        ans = (ans + 1ll * poww(i + val[1] , MOD - 2) * v[2][i]) % MOD;
    cout << 1ll * ans * val[1] % MOD;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Itst/p/10274435.html
今日推荐