【模板】普通平衡树(spaly)

题目描述

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入xx数
  2. 删除xx数(若有多个相同的数,因只删除一个)
  3. 查询xx数的排名(排名定义为比当前数小的数的个数+1+1。若有多个相同的数,因输出最小的排名)
  4. 查询排名为xx的数
  5. xx的前驱(前驱定义为小于xx,且最大的数)
  6. xx的后继(后继定义为大于xx,且最小的数)

输入输出格式

输入格式:

第一行为nn,表示操作的个数,下面nn行每行有两个数optopt和xx,optopt表示操作的序号( 1 \leq opt \leq 61opt6 )

输出格式:

对于操作3,4,5,6每行输出一个数,表示对应答案

代码

ch[N][2]:ch[x][0]代表 xx 的左儿子,ch[x][1]代表 xx 的右儿子。
val[N]:val[x]代表 xx 存储的值。
cnt[N]:cnt[x]代表 xx 存储的重复权值的个数。
par[N]:par[x]代表 xx 的父节点。
size[N]:size[x]代表 xx 子树下的储存的权值数(包括重复权值)。

#include<bits/stdc++.h>
#define inf 1<<30 
using namespace std; 
const int maxn=2e5+10; 
int ch[maxn][2],par[maxn],val[maxn],cnt[maxn],size[maxn];
int ncnt,root;//ncnt新建结点位置,root 表示根节点 
int n;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch<='9'&&ch>='0'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x*f;
}
bool chk(int x)
{
    return ch[par[x]][1]==x;
}
void pushup(int x)
{
    size[x]=size[ch[x][0]]+size[ch[x][1]]+cnt[x];
}
void rotate(int x)
{
    int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
    ch[y][k]=w;par[w]=y;
    ch[z][chk(y)]=x;par[x]=z;
    ch[x][k^1]=y;par[y]=x;
    pushup(y);pushup(x);
}
void splay(int x,int goal=0)
{
    while(par[x]!=goal)
    {
        int y=par[x],z=par[y];
        if(z!=goal)
        {   //两个结点位置相同
            if(chk(x)==chk(y))rotate(y);
            else rotate(x); 
        }
        rotate(x);
    }
    if(!goal)root=x;
}
void insert(int x)
{
    int cur=root,p=0;//p记录当前节点 
    while(cur&&val[cur]!=x)
    p=cur,cur=ch[cur][x>val[cur]];
    if(cur)cnt[cur]++;
    else
    {
        cur=++ncnt;
        if(p)ch[p][x>val[p]]=cur; 
        ch[cur][0]=ch[cur][1]=0; 
        par[cur]=p;val[cur]=x;
        cnt[cur]=size[cur]=1; 
    }
    splay(cur);
}
void find(int x)//把某点旋到根节点 
{
    int cur=root;
    while(ch[cur][x>val[cur]]&&x!=val[cur])
    cur=ch[cur][x>val[cur]];//找到该节点 
    splay(cur);
}
int kth(int k)
{
    int cur=root;
    while(1)
    {
        if(ch[cur][0]&&k<=size[ch[cur][0]])
        cur=ch[cur][0];
        else if(k>size[ch[cur][0]]+cnt[cur])
        k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
        else return cur;
    }
}
int pre(int x)
{
    find(x);
    if(val[root]<x)return root;
    int cur=ch[root][0];
    while(ch[cur][1])cur=ch[cur][1];
    return cur;
}
int succ(int x)
{
    find(x);
    if(val[root]>x)return root;
    int cur=ch[root][1];
    while(ch[cur][0])cur=ch[cur][0];
    return cur;
}
void remove(int x)
{
    int last=pre(x),next=succ(x);
    splay(last);splay(next,last);
    int del=ch[next][0];//表示要删的点 
    if(cnt[del]>1)
    cnt[del]--,splay(del);//更新size标记 
    else ch[next][0]=0; 
}
int main()
{
    n=read();
    insert(inf);
    insert(-inf);
    for(int i=1;i<=n;i++)
    {
        int op=read(),x=read();
        if(op==1)insert(x);
        else if(op==2)remove(x);
        else if(op==3)find(x),printf("%d\n",size[ch[root][0]]);
        else if(op==4)printf("%d\n",val[kth(x+1)]);
        else if(op==5)printf("%d\n",val[pre(x)]);
        else if(op==6)printf("%d\n",val[succ(x)]);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/DriverBen/p/10410426.html