PKUWC2018 minimax

PKUWC2018 minimax

题面描述

一个大小为\(n\)的二叉树,每个叶子结点都有一个互不相同的权值。

每个非叶子结点\(x\)都有一个概率\(p_x\),表示它有\(p_x\)的概率选择它所有儿子权值的最大值,\(1-p_x\)的概率选择它所有儿子权值的最小值。

求出最后根节点取每个权值的概率。

最后把答案以某种方式压缩输出。

答案对\(998244353\)取模。

思路

线段树合并。

如果当前节点为的权值为\(x\),则含\(x\)的子树必须选择\(x\)

要么总体选择最大值,其他子树权值小于\(x\)

要么总体选择最小值,其他子树权值大于\(x\)

维护一个区间和和区间乘法的标记即可。

代码

#include<bits/stdc++.h>
using namespace std;
const int sz=3e5+7;
const int mod=998244353;
int n,m;
int cnt,ans;
int rt[sz];
int f[sz];
int p[sz];
int a[sz];
int inv[sz];
int tr[sz*40],tag[sz*40];
int ls[sz*40],rs[sz*40];
int c[sz][2],t[sz];
void init(){
    inv[1]=1;
    for(int i=2;i<sz;i++)
        inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
void update(int &o,int l,int r,int pos,int v){
    if(!o) o=++cnt,tag[o]=1;
    if(l==r) return (void)(tr[o]=v);
    int mid=(l+r)>>1;
    if(pos<=mid) update(ls[o],l,mid,pos,v);
    else update(rs[o],mid+1,r,pos,v);
    tr[o]=(tr[ls[o]]+tr[rs[o]])%mod;
}
void pd(int o){
    if(ls[o]){
        tag[ls[o]]=1ll*tag[ls[o]]*tag[o]%mod;
        tr[ls[o]]=1ll*tr[ls[o]]*tag[o]%mod;
    }
    if(rs[o]){
        tag[rs[o]]=1ll*tag[rs[o]]*tag[o]%mod;
        tr[rs[o]]=1ll*tr[rs[o]]*tag[o]%mod;
    }
    tag[o]=1;
}
int merge(int o1,int o2,int l,int r,int lx1,int rx1,int lx2,int rx2,int x){
    if(!o1&&!o2) return 0;
    int p1=1ll*p[x]*inv[10000]%mod;
    int p2=(mod+1-p1)%mod;
    if(!o1){
        o1=o1^o2;
        int sum=(1ll*lx1*p1%mod+1ll*rx1*p2%mod)%mod;
        tr[o1]=1ll*tr[o1]*sum%mod;
        tag[o1]=1ll*tag[o1]*sum%mod;
        return o1;
    }
    if(!o2){
        o1=o1^o2;
        int sum=(1ll*lx2*p1%mod+1ll*rx2*p2%mod)%mod;
        tr[o1]=1ll*tr[o1]*sum%mod;
        tag[o1]=1ll*tag[o1]*sum%mod;
        return o1;
    }
    if(tag[o1]>1) pd(o1);
    if(tag[o2]>1) pd(o2);
    int mid=(l+r)>>1;
    int suml1=tr[ls[o1]],sumr1=tr[rs[o1]];
    int suml2=tr[ls[o2]],sumr2=tr[rs[o2]];
    ls[o1]=merge(ls[o1],ls[o2],l,mid,lx1,(rx1+sumr1)%mod,lx2,(rx2+sumr2)%mod,x);
    rs[o1]=merge(rs[o1],rs[o2],mid+1,r,(lx1+suml1)%mod,rx1,(lx2+suml2)%mod,rx2,x);
    tr[o1]=(tr[ls[o1]]+tr[rs[o1]])%mod;
    return o1;
}
void dfs(int x){
    if(!t[x]) return (void)(update(rt[x],1,m,p[x],1));
    if(c[x][0]) dfs(c[x][0]);
    if(c[x][1]) dfs(c[x][1]);
    rt[x]=rt[c[x][0]];
    if(t[x]==2) rt[x]=merge(rt[x],rt[c[x][1]],1,m,0,0,0,0,x);
}
void getans(int o,int l,int r){
    if(l==r) return (void)(ans=(ans+1ll*a[l]*l%mod*tr[o]%mod*tr[o]%mod)%mod);
    if(tag[o]>1) pd(o);
    int mid=(l+r)>>1;
    getans(ls[o],l,mid);
    getans(rs[o],mid+1,r);
}
int main(){
    init();
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d",&f[i]);
        if(i==1) continue;
        c[f[i]][t[f[i]]++]=i;
    }
    for(int i=1;i<=n;i++){
        scanf("%d",&p[i]);
        if(!t[i]) a[++m]=p[i];
    }
    sort(a+1,a+m+1);
    for(int i=1;i<=n;i++){
        if(t[i]) continue;
        p[i]=lower_bound(a+1,a+m+1,p[i])-a;
    }
    dfs(1);
    getans(rt[1],1,m);
    printf("%d\n",ans);
}

猜你喜欢

转载自www.cnblogs.com/river-flows-in-you/p/11984167.html
今日推荐