[loj#3052] [十二省联考 2019] 春节十二响

题意简述

给定一棵 \(n\) 的节点的树,根为1,每个点有权值 \(M_i\)
要把树分成若干段,每段内不存在“祖先-后代”关系,定义一个段的大小为段中点 \(M_i\) 的最大值。
求所有段的大小之和的最小值。
\(n\leq 2\times 10^5\)


想法

想法一:奇怪的贪心

每次找到全树中没分到段内的有最大权值的点 \(u\) ,则 \(M_u\) 无论如何都要加到最终答案中。
那么下一步贪心就是找到非 \(u\) 的祖先且非 \(u\) 的后代的点中有最大权值的点 \(v\) ,与 \(u\) 分到一个段中。
然后再找与 \(u\)\(v\) 都不存在 “祖先-后代”关系的有最大权值的点加入该段中……以此类推,直到找不到可加入的点,这一个段结束。

要证的话,大概就交换一下。
假设最优解中 \(v\)\(u\) 不在同一个段中,将 \(v\)\(u\) 所在段中与 \(v\) 不兼容的点交换位置,之后仍满足要求且不会更差。

怎样找非“祖先-后代”关系的点中权值最大的点呢?
树剖+线段树。
段中每加入一个点,就把它的祖先和后代在线段树中“盖住”,一个段结束后再统一把所有“盖子”都去掉。
很坑,细节极多,实在需要好好注意。(我在细节写炸后还怀疑是算法错了呢 \(qwq\)

想法二:靠谱一些的贪心

树上问题,先考虑子树。
假设已求出子树最优情况下各个段的大小,在对子树进行合并时,显然各自树的最大段合成一段,次大段合成一段,以此类推……
(我也不知道怎么证,但看起来就很靠谱)
对每个子树搞个堆,启发式合并就可以了。


总结

树上问题,考虑子树……


代码

启发式合并

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<queue>

using namespace std;

int read(){
    int x=0;
    char ch=getchar();
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    return x;
}

const int N = 200005;
typedef long long ll;

struct node{
    int v;
    node *nxt;
}pool[N],*h[N];
int cnt;
void addedge(int u,int v){
    node *p=&pool[++cnt];
    p->v=v;p->nxt=h[u];h[u]=p;
}

int n,M[N],f[N],tl;
int st[N];
priority_queue<int> q[N];
void merge(int u,int v){
    if(q[u].size()<q[v].size()) swap(q[u],q[v]);
    tl=0;
    while(q[v].size()){
        st[tl++]=max(q[v].top(),q[u].top());
        q[v].pop(); q[u].pop();
    }
    for(int i=0;i<tl;i++) q[u].push(st[i]);
}
void work(int u){
    int v;
    for(node *p=h[u];p;p=p->nxt)
        work(v=p->v),merge(u,v);
    q[u].push(M[u]);
}

int main()
{
    n=read(); 
    for(int i=1;i<=n;i++) M[i]=read();
    for(int i=2;i<=n;i++) f[i]=read(),addedge(f[i],i);
    work(1);
    ll ans=0;
    for(;q[1].size();) ans+=q[1].top(),q[1].pop();
    printf("%lld\n",ans);
    
    return 0;
}

树剖+线段树

#include<cstdio>
#include<iostream>
#include<algorithm>

using namespace std;

int read(){
    int x=0;
    char ch=getchar();
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    return x;
}

const int N = 200005;
typedef pair<int,int> Pr;
typedef long long ll;

int n,M[N],f[N];

struct node{
    int v;
    node *nxt;
}pool[N],*h[N];
int cnt1;
void addedge(int u,int v){
    node *p=&pool[++cnt1];
    p->v=v;p->nxt=h[u];h[u]=p;
}

int dfn[N],tot,son[N],sz[N],top[N],re[N],out[N];
void dfs1(int u){
    int v,Mson=0;
    sz[u]=1;
    for(node *p=h[u];p;p=p->nxt){
        dfs1(v=p->v);
        sz[u]+=sz[v];
        if(sz[v]>Mson) son[u]=v,Mson=sz[v];
    }
}
void dfs2(int u){
    int v=son[u];
    if(v){
        top[v]=top[f[v]];
        dfn[v]=++tot;
        re[tot]=v;
        dfs2(v);
    }
    for(node *p=h[u];p;p=p->nxt)
        if(!dfn[v=p->v]){
            top[v]=v;
            dfn[v]=++tot;
            re[tot]=v;
            dfs2(v);
        }
    out[u]=tot;
}

int cnt,root,ch[N*2][2],cov[N*2],use[N*2];
Pr ori[N*2],mx[N*2];
void build(int x,int l,int r){
    cov[x]=0;
    if(l==r) { mx[x]=Pr(M[re[l]],re[l]); ori[x]=mx[x]; return; }
    int mid=(l+r)>>1;
    build(ch[x][0]=++cnt,l,mid);
    build(ch[x][1]=++cnt,mid+1,r);
    mx[x]=max(mx[ch[x][0]],mx[ch[x][1]]);
    ori[x]=mx[x];
} 
void push(int x){
    if(!x) return;
    use[x]=1; cov[x]=1; mx[x]=Pr(0,0);
}
void pushdown(int x){
    if(!cov[x]) return;
    push(ch[x][0]); push(ch[x][1]);
    cov[x]=0;
}
void modify(int x,int l,int r,int L,int R){
    use[x]=1;
    if(L<=l && r<=R) { push(x); return; }
    pushdown(x);
    int mid=(l+r)>>1;
    if(L<=mid) modify(ch[x][0],l,mid,L,R);
    if(R>mid) modify(ch[x][1],mid+1,r,L,R);
    mx[x]=max(mx[ch[x][0]],mx[ch[x][1]]);
}
void recover(int x,int l,int r){
    if(!use[x]) return;
    use[x]=0; cov[x]=0; mx[x]=ori[x];
    if(l==r) return;
    int mid=(l+r)>>1;
    recover(ch[x][0],l,mid); recover(ch[x][1],mid+1,r);
}
Pr Max(int x,int l,int r,int L,int R){
    if(L<=l && r<=R) return mx[x];
    pushdown(x);
    int mid=(l+r)>>1;
    Pr ret(0,0);
    if(L<=mid) ret=max(ret,Max(ch[x][0],l,mid,L,R));
    if(R>mid) ret=max(ret,Max(ch[x][1],mid+1,r,L,R));
    return ret;
}
void change(int x,int l,int r,int c){
    use[x]=1;
    if(l==r) { mx[x]=ori[x]=Pr(0,0); return; }
    pushdown(x);
    int mid=(l+r)>>1;
    if(c<=mid) change(ch[x][0],l,mid,c);
    else change(ch[x][1],mid+1,r,c);
    mx[x]=max(mx[ch[x][0]],mx[ch[x][1]]);
    ori[x]=max(ori[ch[x][0]],ori[ch[x][1]]);
}
void jump(int x){
    modify(root,1,n,dfn[x],out[x]);/**/
    while(x){
        modify(root,1,n,dfn[top[x]],dfn[x]);
        x=f[top[x]];
    }
}

int main()
{
    n=read();
    for(int i=1;i<=n;i++) M[i]=read();
    for(int i=2;i<=n;i++) f[i]=read(),addedge(f[i],i);
    
    dfs1(1);
    top[1]=1; dfn[1]=++tot; re[tot]=1; dfs2(1);
    build(root=++cnt,1,n);
    
    ll ans=0;
    int t=n;
    while(t){
        Pr w=Max(root,1,n,1,n);
        ans+=w.first;
        change(root,1,n,dfn[w.second]); t--; /*dfn*/
        jump(w.second); 
        for(;;){
            w=Max(root,1,n,1,n);
            if(w.second==0) break;
            change(root,1,n,dfn[w.second]); t--; /*dfn*/
            jump(w.second);
        }
        recover(root,1,n);
    }
    printf("%lld\n",ans);
    
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/lindalee/p/12458565.html
今日推荐