线段树维护区间合并——cf1285E

感觉自己的解法又是歪的

代码写的很乱。。要先找出一开始有多少段,然后计算删掉每条线段的贡献,求个最大值就可以

删每条线段的贡献可以用线段树区间合并来做

ps:正解其实很简单。。扫描一下就可以了

/*
先把所有线段覆盖到线段树上
然后对每一个线段[L,R],查询区间[L,R]有多少值>1的段即可 
*/
#include<bits/stdc++.h>
using namespace std;
#define N 800005
 
int n,x[N],L[N],R[N],tot;
 
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
struct Seg{
    Seg(){}
    int lazy,ltag,rtag,cnt;//左端是否>1,右端是否>1,区间>1的段数
}seg[N<<2]; 
void pushdown(int rt){
    seg[rt<<1].lazy+=seg[rt].lazy;
    seg[rt<<1|1].lazy+=seg[rt].lazy;
    seg[rt].lazy=0;
} 
void pushup(int rt){
    seg[rt].ltag=seg[rt<<1].ltag;
    seg[rt].rtag=seg[rt<<1|1].rtag;
    seg[rt].cnt=seg[rt<<1].cnt+seg[rt<<1|1].cnt;
    if(seg[rt<<1].rtag && seg[rt<<1|1].ltag)
        seg[rt].cnt--;
}
Seg merge(Seg a,Seg b){
    Seg res;
    res.cnt=a.cnt+b.cnt;
    res.ltag=a.ltag;
    res.rtag=b.rtag;
    if(a.rtag && b.ltag)res.cnt--;
    return res;
}
void build(int l,int r,int rt){
    seg[rt].cnt=seg[rt].lazy=seg[rt].ltag=seg[rt].rtag=0;
    if(l==r)return;
    int m=l+r>>1;
    build(lson);build(rson);
}
void update(int L,int R,int l,int r,int rt){
    if(L<=l && R>=r){seg[rt].lazy++;return;}
    pushdown(rt);
    int m=l+r>>1;
    if(L<=m)update(L,R,lson);
    if(R>m)update(L,R,rson);
}
void rollback(int l,int r,int rt){
    if(l==r){
        if(seg[rt].lazy>1)
            seg[rt].ltag=seg[rt].rtag=seg[rt].cnt=1;
        return;    
    }
    pushdown(rt);
    int m=l+r>>1;
    rollback(lson);rollback(rson);
    pushup(rt);
} 
Seg query(int L,int R,int l,int r,int rt){
    if(L<=l && R>=r)return seg[rt];
    int m=l+r>>1,flag=0;
    Seg res;
    if(L<=m)res=query(L,R,lson),flag=1;
    if(R>m){
        if(flag)
            res=merge(res,query(L,R,rson));
        else res=query(L,R,rson);
    }
    return res; 
}
void init(){
    tot=0;
}
void debug(int l,int r,int rt){
    cout<<l<<" "<<r<<" "<<rt<<" "<<seg[rt].cnt<<" "<<seg[rt].lazy<<" "<<seg[rt].rtag<<'\n';
    if(l==r)return;
    int m=l+r>>1;
    debug(lson);debug(rson); 
}
 
int cnt[N];
int main(){
    int T;cin>>T;
    while(T--){
        cin>>n;
        init();
        for(int i=1;i<=n;i++){
            scanf("%d%d",&L[i],&R[i]);
            x[++tot]=L[i];x[++tot]=R[i];    
        }
        sort(x+1,x+1+tot);
        tot=unique(x+1,x+1+tot)-x-1;
        for(int i=1;i<=n;i++){
            L[i]=lower_bound(x+1,x+1+tot,L[i])-x;
            R[i]=lower_bound(x+1,x+1+tot,R[i])-x;    
            L[i]<<=1;R[i]<<=1;
        }
        tot<<=1;
        build(1,tot,1);
        for(int i=1;i<=n;i++){
            update(L[i],R[i],1,tot,1);
            //debug(1,tot,1);
        }
        rollback(1,tot,1);//求出每段的1的个数 
        
        for(int i=1;i<=tot;i++)cnt[i]=0;
        for(int i=1;i<=n;i++)cnt[L[i]]++,cnt[R[i]+1]--;
        for(int i=1;i<=tot;i++)cnt[i]+=cnt[i-1];
        int len=0;
        for(int i=1;i<=tot;i++)
            if(cnt[i] && cnt[i-1]==0)len++;
        
        int Max=-0x3f3f3f3f;
        for(int i=1;i<=n;i++)
            Max=max(Max,query(L[i],R[i],1,tot,1).cnt-1);
        cout<<Max+len<<'\n'; 
    }
} 

猜你喜欢

转载自www.cnblogs.com/zsben991126/p/12273691.html