P3369 平衡树treap模板

第一次学,调了两天

一个操作一个操作的讲

插入,这个没啥好说的,将一个点插入树中,定义一个insert(int &o,int v)函数,o传的引用,改变o值时,root也会跟着改变,先判断当前传入值是否存在值,即当前o是否等于0,等于0则没有用过,初始化一个prio=rand(),即给当前插入点随机附一个优先级,防止出题人对着你卡数据把你卡成一条链,用prio来旋转,初始当前点的size为1;

push_up操作,更改当前修改点的size,即子树大小,当前点的子树大小等于左儿子子树大小+右儿子子树大小+1

rotate旋转,向右旋转则把当前点接在他的左儿子的右子树上,来确保中序遍历不变,中序遍历的意思是优先走左子树走到叶子节点后开始回溯输出,输出结果一定是升序,自己手摸,向左旋转同理,转完之后更新两个节点的size

remove删除操作,先找到值为当前删除点的点的标号,然后先看当前点的左右子树哪个为空,然后就把他和另一个子树交换位置,直到当前点成为叶子节点,然后删掉它,如果都不为空,判断哪科子树优先级更小,就与那棵子树进行交换,直到换到叶子节点,过程中注意要保证中序遍历,随时更新size

大于等于的最小值,和线段树相似,随机从一个点进入,然后比较他的val与当前查询值得大小,小于则搜右子树,反之则搜索左子树然后和当前点取min

小于等于的最大值,与上一个操作相似,只不过把第二种情况换成取max

获取当前值的排名,随机进入,比较大小,小的话,则说明在当前点的前面,搜索右子树,加上左子树大小还有当前点的1,否则搜索右子树

获取排名为k的值,说白了求第k大,与上面实现相反,判断子树大小与排名的关系,bulabula,嗯,

上代码

//By Acer.Mo
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAX_N = 3e5 + 10;
const int inf = 0x3f3f3f3f;
struct node 
{
    int ch[2];
    int val, size, prio;
};
node t[MAX_N];
int pool_cur;
int del_pool[MAX_N], del_cur;
int root;
void init() 
{
    pool_cur = 1;
    root = 0;
}
void push_up(int o) 
{
//	cout<<"ooo="<<o<<"  lss="<<t[t[o].ch[0]].size<<"  rss="<< t[t[o].ch[1]].size <<endl;
    t[o].size = t[t[o].ch[0]].size + t[t[o].ch[1]].size + 1;
}
void rotate(int &o, int d) 
{
    int u = t[o].ch[d];
    t[o].ch[d] = t[u].ch[d ^ 1];
    t[u].ch[d ^ 1] = o;
    //t[u].size = t[o].size;
    push_up(o);
    push_up(u);
    o = u;
}
void insert(int &o, int v) 
{ 
    if (!o) 
    {
        o = pool_cur++;
        t[o].size=1;
        t[o].prio=rand();
        t[o].val = v;
        return ;
    }
    t[o].size++;
    int d = t[o].val < v;
    insert(t[o].ch[d], v);
    push_up(o);
    if (t[t[o].ch[d]].prio < t[o].prio) rotate(o, d);
    return ;
}
void remove(int &o, int v) 
{ 
    if (!o) return ;
    if (t[o].val == v) 
    {
        int u = o;
        if (t[o].ch[0]*t[o].ch[1]==0) 
        {
            o=t[o].ch[0]+t[o].ch[1];
            return ;
        }
        else 
        {
            int d = (t[t[o].ch[0]].prio<t[t[o].ch[1]].prio);
            rotate(o, d ^ 1);
            remove(t[o].ch[d], v);
        }
    } 
    else 
    {
        int d =(t[o].val < v);
        remove(t[o].ch[d], v);
    }
    push_up(o);
    return ;
}
int lower(int o,int v)
{
    if (o==0) return -1<<30;
    if (t[o].val<v) return max(t[o].val,lower(t[o].ch[1],v));
    else return lower(t[o].ch[0],v); 
}
int uper(int o,int v)
{
    if (!o) return inf;
    if (t[o].val>v) return min(t[o].val,uper(t[o].ch[0],v));
    else return uper(t[o].ch[1],v);
}
int getrank(int o,int v)
{
    if (!o) return 1;
    if (t[o].val>=v) return getrank(t[o].ch[0],v);
    return getrank(t[o].ch[1],v)+1+t[t[o].ch[0]].size;
}
int find_kth(int o, int k) 
{ 
    //cout<<"ooo="<<o<<" kkk="<<k<<endl;
    if (!o) return 0;
    int d = k - t[t[o].ch[0]].size;//cout<<"size="<<t[t[o].ch[0]].size<<endl;
    if (d <= 0) return find_kth(t[o].ch[0], k);
    if (d == 1) return t[o].val;
    else return find_kth(t[o].ch[1], d - 1);
}
int main() 
{
    init();
    int n;
    cin>>n;
    int flag,num,ans;
    for (int i=1;i<=n;i++)
    {
        ans=0;
        scanf("%d %d",&flag,&num);
        if (flag==1) insert(root,num);
        else if (flag==2) remove(root,num);
        else if (flag==3) ans=getrank(root,num),printf("%d\n",ans);
        else if (flag==4) ans=find_kth(root,num),printf("%d\n",ans);
        else if (flag==5) ans=lower(root,num),printf("%d\n",ans);
        else if (flag==6) ans=uper(root,num),printf("%d\n",ans);
    }
    return 0; 
}

猜你喜欢

转载自blog.csdn.net/acerandaker/article/details/80159685