「PKUWC2018」猎人杀(分治NTT+概率期望)

Description

猎人杀是一款风靡一时的游戏“狼人杀”的民间版本,他的规则是这样的:

一开始有 \(n\) 个猎人,第 \(i\) 个猎人有仇恨度 \(w_i\) ,每个猎人只有一个固定的技能:死亡后必须开一枪,且被射中的人也会死亡。

然而向谁开枪也是有讲究的,假设当前还活着的猎人有 \([i_1,i_2,...,i_m]\),那么有 \(\frac{w_{i_k}}{\sum_{j=1}^nw_{i_j}}\) 的概率是向猎人 \(k\) 开枪。

一开始第一枪由你打响,目标的选择方法和猎人一样(即有 \(\frac{w_i}{\sum_{j=1}^nw_j}\) 的概率射中第 \(i\) 个猎人)。由于开枪导致的连锁反应,所有猎人最终都会死亡,现在 \(1\) 号猎人想知道它是最后一个死的的概率。

答案对 \(998244353\) 取模。

【输入格式】
第一行一个正整数 \(n\)

第二行 \(n\) 个正整数,第 \(i\) 个正整数表示 \(w_i\)

【输出格式】
输出一个非负整数表示答案。

【输入样例】

3
1 1 2

【输出样例】

915057324

【样例解释】
答案是 \(\frac{2}{4}×\frac{1}{2}+\frac{1}{4}×\frac{2}{3}=\frac{5}{12}\)

【数据规模与约定】
对于 \(10\%\) 的数据,有 \(1\leq n\leq 10\)

对于 \(30\%\) 的数据,有 \(1\leq n\leq 20\)

对于 \(50\%\) 的数据,有 \(1\leq \sum\limits_{i=1}^{n}w_i\leq 5000\)

另有 \(10\%\) 的数据,满足 \(1\leq w_i\leq 2\),且 \(w_1=1\)

另有 \(10\%\) 的数据,满足 \(1\leq w_i\leq 2\),且 \(w_1=2\)

对于 \(100\%\) 的数据,有 \(w_i>0\),且 \(1\leq \sum\limits_{i=1}^{n}w_i \leq 100000\)

Solution

考虑容斥,即枚举强制在 \(1\) 号之后死的人。设 \(T\) 为枚举到的人的集合,\(S\)\(T\) 中的 \(w_i\) 之和。

考虑怎么求 \(T\) 中的人都在 \(1\) 号之后死的概率。可以将它们合并成 \(0\) 号猎人,\(w_0=S\)。那么现在 \(\lceil\) \(0\) 号在 \(1\) 号之后死的概率 \(\rfloor\) 就是 \(\lceil\) \(T\) 中的人都在 \(1\) 号之后死的概率 \(\rfloor\)。显然 \(0\) 号和 \(1\) 号谁先死不受其它猎人影响,那么 \(\lceil\) \(0\) 号在 \(1\) 号之后死的概率 \(\rfloor\) 就是 \(\frac{w_1}{S+w_1}\),所以 \(\lceil\) \(T\) 中的人都在 \(1\) 号之后死的概率 \(\rfloor\) 也是 \(\frac{w_1}{S+w_1}\)

集合 \(T\) 对答案的贡献为 \((-1)^{|T|}×\frac{w_1}{S+w_1}\)

发现 \(\sum w_i \leq 10^5\),考虑对于每个 \(S\),求出 \(b_S\) 表示满足\(w_i\) 之和为 \(S\) 的集合 \(T\)\((-1)^T\) 之和。 那么 \(ans=\sum b_S×\frac{w_1}{S+w_1}\)

显然 \(b_S\) 就是多项式 \(\Pi _{i=2}^n(1-x^{w_i})\)\(x^S\) 项的系数,分治 \(\text{NTT}\) 即可。

\(m=\sum_{i=1}^n w_i\),时间复杂度 \(O(m \log m)\)

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

template <class t>
inline void read(t & res)
{
    char ch;
    while (ch = getchar(), !isdigit(ch));
    res = ch - 48;
    while (ch = getchar(), isdigit(ch))
    res = res * 10 + (ch ^ 48);
}

const int e = 2e5 + 5, mod = 998244353;
vector<int>g[e];
int rev[e], n, ans, val[e], lim;

inline int ksm(int x, int y)
{
    int res = 1;
    while (y)
    {
        if (y & 1) res = (ll)res * x % mod;
        y >>= 1;
        x = (ll)x * x % mod;
    }
    return res;
}

inline void upt(int &x, int y)
{
    x = y;
    if (x >= mod) x -= mod;
}

inline void fft(int *a, int n, int op)
{
    int i, j, k, r = (op == 1 ? 3 : (mod + 1) / 3);
    for (i = 0; i < n; i++)
    if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (k = 1; k < n; k <<= 1)
    {
        int w0 = ksm(r, (mod - 1) / (k << 1));
        for (i = 0; i < n; i += (k << 1))
        {
            int w = 1;
            for (j = 0; j < k; j++)
            {
                int b = a[i + j], c = (ll)w * a[i + j + k] % mod;
                upt(a[i + j], b + c);
                upt(a[i + j + k], b + mod - c);
                w = (ll)w * w0 % mod;
            }
        }
    }
}

inline void modify(int *a, int *b, int la, int lb)
{
    int i;
    fft(a, lim, 1);
    fft(b, lim, 1);
    for (i = 0; i < lim; i++) a[i] = (ll)a[i] * b[i] % mod;
    fft(a, lim, -1);
    int tot = ksm(lim, mod - 2);
    for (i = 0; i < la + lb - 1; i++) a[i] = (ll)a[i] * tot % mod;
}

inline void solve(int l, int r)
{
    if (l >= r) return;
    int mid = l + r >> 1;
    solve(l, mid); solve(mid + 1, r);
    int i, la = g[l].size(), lb = g[mid + 1].size();
    static int a[e], b[e];
    int k = 0; lim = 1;
    while (lim < la + lb - 1) lim <<= 1, k++;
    for (i = 0; i < lim; i++)
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << k - 1), a[i] = b[i] = 0;
    for (i = 0; i < la; i++) a[i] = g[l][i];
    for (i = 0; i < lb; i++) b[i] = g[mid + 1][i];
    modify(a, b, la, lb);
    g[l].resize(la + lb - 1);
    for (i = 0; i < la + lb - 1; i++) g[l][i] = a[i];
}

int main()
{
    int i, sum = 0;
    read(n);
    for (i = 1; i <= n; i++) read(val[i]), sum += val[i];
    sum -= val[1];
    g[1].push_back(1);
    for (i = 2; i <= n; i++)
    {
        g[i].resize(val[i] + 1);
        g[i][0] = 1;
        g[i][val[i]] = mod - 1;
    }
    solve(1, n);
    for (i = val[1]; i <= sum + val[1]; i++)
    {
        int x = i - val[1], inv = ksm(i, mod - 2);
        ans = (ans + (ll)g[1][x] * inv) % mod;
    }
    ans = (ll)ans * val[1] % mod;
    cout << ans << endl;
    fclose(stdin);
    fclose(stdout);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/cyf32768/p/12196025.html