[PKUWC2018] Minimax

Description

给定一棵 \(n\) 个节点的树,每个节点最多有两个子节点。

如果 \(x\) 是叶子,则给定 \(x\) 的权值;否则,它的权值有 \(p_x\) 的概率是它子节点中权值的较大值,\(1-p_x\) 的概率是它子节点中权值的较小值。保证叶子结点权值互不相同。

求根节点所有可能的权值的概率。模 \(998244353\)

Solution

嗯比较自然的一道题。

\(f_{i,x}\) 为结点 \(i\) 权值为 \(x\) 的概率,\(l,r\) 分别是点 \(i\) 的左右子树,则有(假设权值 \(x\)\(l\) 中出现):

\[f_{i,x}=\sum_{j=1}^{x-1}f_{l,x}\cdot f_{r,j}\cdot p_i+\sum_{j=x+1}^n f_{i,x}\cdot f_{r,j}\cdot (1-p_i)\]

发现这上是一个合并的过程,可以拿线段树合并做。

中间维护两棵树的前缀和后缀和,以及打好标记即可。

Code

LOJ格式化代码真好玩

我能玩一天

放上被LOJ格式化之后的代码

#include <bits/stdc++.h>
using std::max;
using std::min;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int, int>
#define all(A) A.begin(), A.end()
#define mp(A, B) std::make_pair(A, B)
#define int long long
const int N = 3e5 + 5;
const int M = N * 20;
const int mod = 998244353;

int sum[M], flag[M], inv;
int val[N], g[N], is[N], ans;
int n, cnt, leaf, head[N], lef;
int ch[M][2], tot, len, rt[N];

struct Edge {
    int to, nxt;
} edge[N << 1];

#define ls ch[x][0]
#define rs ch[x][1]

void pushup(int x) { sum[x] = (sum[ls] + sum[rs]) % mod; }

void pushdown(int x) {
    if (flag[x] != 1) {
        (flag[ls] *= flag[x]) %= mod;
        (flag[rs] *= flag[x]) %= mod;
        (sum[ls] *= flag[x]) %= mod;
        (sum[rs] *= flag[x]) %= mod;
        flag[x] = 1;
    }
}

void modify(int &x, int l, int r, int ql) {
    x = ++tot;
    flag[x] = 1;
    if (l == r)
        return sum[x] = 1, void();
    int mid = l + r >> 1;
    ql <= mid ? modify(ls, l, mid, ql) : modify(rs, mid + 1, r, ql);
    pushup(x);
}

#undef ls
#undef rs

int merge(int x, int y, int aqzh, int ahzh, int bqzh, int bhzh, int pi) {
    if (!x and !y)
        return 0;
    if (!x) {
        pushdown(y);
        (sum[y] *= ahzh * (1 - pi + mod) % mod + aqzh * pi % mod) %= mod;
        (flag[y] *= ahzh * (1 - pi + mod) % mod + aqzh * pi % mod) %= mod;
        return y;
    }
    if (!y) {
        pushdown(x);
        (sum[x] *= bhzh * (1 - pi + mod) % mod + bqzh * pi % mod) %= mod;
        (flag[x] *= bhzh * (1 - pi + mod) % mod + bqzh * pi % mod) %= mod;
        return x;
    }
    int now = ++tot;
    flag[now] = 1;
    pushdown(x), pushdown(y);
    int a = sum[ch[x][0]], b = sum[ch[y][0]];
    ch[now][0] =
        merge(ch[x][0], ch[y][0], aqzh, (ahzh + sum[ch[x][1]]) % mod, bqzh, (bhzh + sum[ch[y][1]]) % mod, pi);
    ch[now][1] = merge(ch[x][1], ch[y][1], (aqzh + a) % mod, ahzh, (bqzh + b) % mod, bhzh, pi);
    pushup(now);
    return now;
}

void add(int x, int y) {
    edge[++cnt].to = y;
    edge[cnt].nxt = head[x];
    head[x] = cnt;
}

int ksm(int a, int b, int ans = 1) {
    while (b) {
        if (b & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}

int getint() {
    int X = 0, w = 0;
    char ch = getchar();
    while (!isdigit(ch)) w |= ch == '-', ch = getchar();
    while (isdigit(ch)) X = X * 10 + ch - 48, ch = getchar();
    if (w)
        return -X;
    return X;
}

void dfs(int now) {
    if (!is[now])
        return;
    int tot = 0;
    for (int i = head[now]; i; i = edge[i].nxt) {
        int to = edge[i].to;
        tot++;
        dfs(to);
    }
    if (tot == 1) {
        for (int i = head[now]; i; i = edge[i].nxt) {
            int to = edge[i].to;
            rt[now] = rt[to];
        }
    } else {
        tot = 0;
        int a, b;
        for (int i = head[now]; i; i = edge[i].nxt) {
            int to = edge[i].to;
            tot == 1 ? b = to : a = to, tot++;
        }
        rt[now] = merge(rt[a], rt[b], 0, 0, 0, 0, val[now] * inv % mod);
    }
}

void dfs2(int now, int l, int r) {
    if (!now)
        return;
    pushdown(now);
    if (l == r)
        return (ans += (lef + 1) * g[l] % mod * sum[now] % mod * sum[now] % mod) %= mod, lef++, void();
    int mid = l + r >> 1;
    dfs2(ch[now][0], l, mid);
    dfs2(ch[now][1], mid + 1, r);
}

signed main() {
    n = getint();
    getint();
    inv = ksm(10000, mod - 2);
    for (int i = 2; i <= n; i++) {
        int x = getint();
        add(x, i);
        is[x] = 1;
    }
    for (int i = 1; i <= n; i++) {
        val[i] = getint();
        if (!is[i])
            g[++len] = val[i];
    }
    std::sort(g + 1, g + 1 + len);
    len = std::unique(g + 1, g + 1 + len) - g - 1;
    for (int i = 1; i <= n; i++) {
        if (!is[i]) {
            val[i] = std::lower_bound(g + 1, g + 1 + len, val[i]) - g;
            modify(rt[i], 1, len, val[i]);
        }
    }
    dfs(1);
    dfs2(rt[1], 1, len);
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/YoungNeal/p/10300660.html
今日推荐