第一次学,调了两天
一个操作一个操作的讲
插入,这个没啥好说的,将一个点插入树中,定义一个insert(int &o,int v)函数,o传的引用,改变o值时,root也会跟着改变,先判断当前传入值是否存在值,即当前o是否等于0,等于0则没有用过,初始化一个prio=rand(),即给当前插入点随机附一个优先级,防止出题人对着你卡数据把你卡成一条链,用prio来旋转,初始当前点的size为1;
push_up操作,更改当前修改点的size,即子树大小,当前点的子树大小等于左儿子子树大小+右儿子子树大小+1
rotate旋转,向右旋转则把当前点接在他的左儿子的右子树上,来确保中序遍历不变,中序遍历的意思是优先走左子树走到叶子节点后开始回溯输出,输出结果一定是升序,自己手摸,向左旋转同理,转完之后更新两个节点的size
remove删除操作,先找到值为当前删除点的点的标号,然后先看当前点的左右子树哪个为空,然后就把他和另一个子树交换位置,直到当前点成为叶子节点,然后删掉它,如果都不为空,判断哪科子树优先级更小,就与那棵子树进行交换,直到换到叶子节点,过程中注意要保证中序遍历,随时更新size
大于等于的最小值,和线段树相似,随机从一个点进入,然后比较他的val与当前查询值得大小,小于则搜右子树,反之则搜索左子树然后和当前点取min
小于等于的最大值,与上一个操作相似,只不过把第二种情况换成取max
获取当前值的排名,随机进入,比较大小,小的话,则说明在当前点的前面,搜索右子树,加上左子树大小还有当前点的1,否则搜索右子树
获取排名为k的值,说白了求第k大,与上面实现相反,判断子树大小与排名的关系,bulabula,嗯,
上代码
//By Acer.Mo #include<cmath> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> using namespace std; const int MAX_N = 3e5 + 10; const int inf = 0x3f3f3f3f; struct node { int ch[2]; int val, size, prio; }; node t[MAX_N]; int pool_cur; int del_pool[MAX_N], del_cur; int root; void init() { pool_cur = 1; root = 0; } void push_up(int o) { // cout<<"ooo="<<o<<" lss="<<t[t[o].ch[0]].size<<" rss="<< t[t[o].ch[1]].size <<endl; t[o].size = t[t[o].ch[0]].size + t[t[o].ch[1]].size + 1; } void rotate(int &o, int d) { int u = t[o].ch[d]; t[o].ch[d] = t[u].ch[d ^ 1]; t[u].ch[d ^ 1] = o; //t[u].size = t[o].size; push_up(o); push_up(u); o = u; } void insert(int &o, int v) { if (!o) { o = pool_cur++; t[o].size=1; t[o].prio=rand(); t[o].val = v; return ; } t[o].size++; int d = t[o].val < v; insert(t[o].ch[d], v); push_up(o); if (t[t[o].ch[d]].prio < t[o].prio) rotate(o, d); return ; } void remove(int &o, int v) { if (!o) return ; if (t[o].val == v) { int u = o; if (t[o].ch[0]*t[o].ch[1]==0) { o=t[o].ch[0]+t[o].ch[1]; return ; } else { int d = (t[t[o].ch[0]].prio<t[t[o].ch[1]].prio); rotate(o, d ^ 1); remove(t[o].ch[d], v); } } else { int d =(t[o].val < v); remove(t[o].ch[d], v); } push_up(o); return ; } int lower(int o,int v) { if (o==0) return -1<<30; if (t[o].val<v) return max(t[o].val,lower(t[o].ch[1],v)); else return lower(t[o].ch[0],v); } int uper(int o,int v) { if (!o) return inf; if (t[o].val>v) return min(t[o].val,uper(t[o].ch[0],v)); else return uper(t[o].ch[1],v); } int getrank(int o,int v) { if (!o) return 1; if (t[o].val>=v) return getrank(t[o].ch[0],v); return getrank(t[o].ch[1],v)+1+t[t[o].ch[0]].size; } int find_kth(int o, int k) { //cout<<"ooo="<<o<<" kkk="<<k<<endl; if (!o) return 0; int d = k - t[t[o].ch[0]].size;//cout<<"size="<<t[t[o].ch[0]].size<<endl; if (d <= 0) return find_kth(t[o].ch[0], k); if (d == 1) return t[o].val; else return find_kth(t[o].ch[1], d - 1); } int main() { init(); int n; cin>>n; int flag,num,ans; for (int i=1;i<=n;i++) { ans=0; scanf("%d %d",&flag,&num); if (flag==1) insert(root,num); else if (flag==2) remove(root,num); else if (flag==3) ans=getrank(root,num),printf("%d\n",ans); else if (flag==4) ans=find_kth(root,num),printf("%d\n",ans); else if (flag==5) ans=lower(root,num),printf("%d\n",ans); else if (flag==6) ans=uper(root,num),printf("%d\n",ans); } return 0; }