Bzoj 3224.普通平衡树 [ 权值线段树 ]

版权声明:博主的文章,请随意转载 https://blog.csdn.net/Acer12138/article/details/82431500

Description

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)

Input

第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)

Output

对于操作3,4,5,6每行输出一个数,表示对应答案

Sample Input

10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598

Sample Output

106465
84185
492737

权值线段树 ,平衡树 ,Treap , AVL , 替罪羊树 , 树状数组 都可以写 博主提供一种权值线段树的

AC code:

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>

#define lson rt<<1
#define rson rt<<1|1

using namespace std;

typedef long long ll;
typedef pair<int,int>pii;

const int maxn = 1e5+50;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;

struct segtree{
    int l,r;
    int sum;
}ss[maxn<<2];

struct Node{
    int op;
    ll w;
}node[maxn];

int cnt = 0;
ll vis[maxn];

void pushUp(int rt) {
    ss[rt].sum = (ss[lson].sum + ss[rson].sum);
}

void build(int l,int r,int rt) {
    ss[rt].l = l; ss[rt].r = r;
    ss[rt].sum = 0;
    if(l == r) return;
    int mid = (l+r)>>1;
    build(l,mid,lson); build(mid+1,r,rson);
}

void update(int pos,int val,int rt) {
    if ( ss[rt].l == ss[rt].r ) {
        ss[rt].sum += val; return;
    }
    int mid = (ss[rt].l + ss[rt].r) >> 1;
    if ( mid >= pos ) update(pos,val,lson);
    else update(pos,val,rson);
    pushUp(rt);
}

/*********************************
re代码 
re原因:当node[i].op 为 3 的时候tmp
有可能为1那么tmp - 1 为零 l > r 条件永远不满足 因此re
在这个函数里面加一句
if ( l > r ) return 0;也可以过
*********************************/

//int query(int l,int r,int rt) {
//    if ( ss[rt].l == l && ss[rt].r == r ) {
//        return ss[rt].sum;
//    }
//    int mid = (ss[rt].l + ss[rt].r) >> 1;
//    if ( mid >= r ) return query(l,r,lson);
//    else if ( l > mid ) return query(l,r,rson);
//    else return query(l,mid,lson) + query(mid+1,r,rson);
//}

/********************************
ac代码 
********************************/ 

int query(int l,int r,int rt) {
    if ( ss[rt].l >= l && ss[rt].r <= r ) {
        return ss[rt].sum;
    }
    int mid = (ss[rt].l + ss[rt].r) >> 1;
    int ans = 0;
    if ( mid >= l ) ans += query(l,r,lson);
    if ( r > mid ) ans += query(l,r,rson);
    return ans;
}

int Kth(int pos,int rt) {
    if ( ss[rt].l == ss[rt].r ) return ss[rt].l;
    int mid = (ss[rt].l + ss[rt].r ) >> 1;
    if ( ss[lson].sum >= pos ) {
        Kth(pos,lson);
    } else {
        Kth(pos-ss[lson].sum,rson);
    }
}

void Print(int l,int r,int rt) {
    if(l == r) {
        printf("ss[%d].sum = %d\n",ss[rt].l,ss[rt].sum);
        return;
    }
    int mid = (l+r)>>1;
    Print(l,mid,lson); Print(mid+1,r,rson);
}

int main(){

    int t; cin>>t;

    for (int i = 1;i<=t;i++) {
        scanf("%d %lld",&node[i].op, &node[i].w );
        if ( node[i].op != 4 ) vis[++cnt] = node[i].w;
    }

    sort(vis+1,vis+cnt+1);
    int len = unique(vis+1,vis+cnt+1) - vis; len--;
//    printf("len = %d\n",len);
//    for (int i = 1;i<=len;i++) printf("%lld ",vis[i]);
//    printf("\n");

    build(1,len,1);

    for (int i = 1;i<=t;i++) {
        int tmp;
        if ( node[i].op != 4 ) 
            tmp = lower_bound(vis+1,vis+len+1,node[i].w) - vis;
        //printf("^ ^ %d %lld\n",tmp,vis[tmp]);
        if(node[i].op == 1) {
            update(tmp,1,1);
            //Print(1,len,1);
        } else if ( node[i].op == 2 ) {
            update(tmp,-1,1);
        } else if ( node[i].op == 3 ) {
            printf("%d\n",query(1,tmp-1,1)+1);
        } else if ( node[i].op == 4 ) {
            //printf("node[i].w =  %lld, Kth() = %d\n",node[i].w,Kth(node[i].w,1));
            printf("%lld\n",vis[Kth(node[i].w,1)]);
        } else if ( node[i].op == 5 ) {
            //printf("node[i].w = %lld ,tmp = %d\n",node[i].w,tmp);
            int ans = query(1,tmp-1,1);
            int temp = Kth(ans,1);
            //Print(1,len,1);
            //printf("ans = %lld ,temp = %d\n",ans,temp);
            printf("%lld\n",vis[temp]);
        } else if ( node[i].op == 6 ) {
            int ans = query(1,tmp,1);
            int temp = Kth(ans+1,1);
            printf("%lld\n",vis[temp]);
        }
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/Acer12138/article/details/82431500