版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
Splay:
这是从洛谷日报第62期学的。
#include <bits/stdc++.h>
#define maxn 500005
using namespace std;
int n,sz,root,fa[maxn],val[maxn],ch[maxn][2],siz[maxn],cnt[maxn];
inline bool isc(int x){return ch[fa[x]][1]==x;}
inline void upd(int x){siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];}
void rot(int x){
int y=fa[x],z=fa[y],c=isc(x);
if(z) ch[z][isc(y)]=x;
(ch[y][c]=ch[x][!c])&&(fa[ch[y][c]]=y);
fa[ch[x][!c]=y]=x,fa[x]=z;
upd(y),upd(x);
}
void splay(int x,int goal=0){
for(int y;(y=fa[x])!=goal;rot(x))
if(fa[y]!=goal) rot(isc(y)==isc(x)?y:x);
if(!goal) root=x;
}
void find(int x){//not exist will return neighbor
int cur=root;
while(ch[cur][x>val[cur]]&&val[cur]!=x) cur=ch[cur][x>val[cur]];
splay(cur);
}
int pre(int x){
find(x);
if(val[root]<x) return root;
int cur=ch[root][0];
while(ch[cur][1]) cur=ch[cur][1];
return cur;
}
int nxt(int x){
find(x);
if(val[root]>x) return root;
int cur=ch[root][1];
while(ch[cur][0]) cur=ch[cur][0];
return cur;
}
void insert(int x){
int cur=root,p=0;
while(cur&&val[cur]!=x) p=cur,cur=ch[cur][x>val[cur]];
if(cur) cnt[cur]++;
else{
fa[cur=++sz]=p,val[cur]=x,siz[cur]=cnt[cur]=1;
if(p) ch[p][x>val[p]]=cur;
}
splay(cur);
}
void del(int x){
int p=pre(x),t=nxt(x);
splay(p),splay(t,p);
if(cnt[p=ch[t][0]]>1) cnt[p]--,splay(p);//修改之后必须要splay!!
else ch[t][0]=0;
}
int rnk(int k){
int cur=root;
while(1){
if(k<=siz[ch[cur][0]]) cur=ch[cur][0];
else if(k>siz[ch[cur][0]]+cnt[cur]) k-=siz[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
else return cur;
}
}
int main()
{
scanf("%d",&n);
insert(-1e9),insert(1e9);
for(int i=1,op,x;i<=n;i++){
scanf("%d%d",&op,&x);
switch(op){
case 1:insert(x);break;
case 2:del(x);break;
case 3:find(x);printf("%d\n",siz[ch[root][0]]);break;
case 4:printf("%d\n",val[rnk(x+1)]);break;
case 5:printf("%d\n",val[pre(x)]);break;
case 6:printf("%d\n",val[nxt(x)]);break;
}
}
}
无旋Treap:
和splay比起来短的多得多得多得多。。。
从这篇博客学的。
#include<bits/stdc++.h>
#define maxn 100005
using namespace std;
int n,lc[maxn],rc[maxn],v[maxn],siz[maxn],rnd[maxn],tot,rt;
inline int Newnode(int x){
siz[++tot]=1,lc[tot]=rc[tot]=0,v[tot]=x,rnd[tot]=rand()<<15|rand();
return tot;
}
inline void upd(int x){siz[x]=siz[lc[x]]+siz[rc[x]]+1;}
void merge(int &now,int a,int b){
if(!a||!b) {now=a+b;return;}
if(rnd[a]<rnd[b]) now=a,merge(rc[now],rc[a],b);
else now=b,merge(lc[now],a,lc[b]);
upd(now);
}
void split(int now,int &a,int &b,int val){
if(!now) {a=b=0;return;}
if(v[now]<=val) a=now,split(rc[now],rc[a],b,val);
else b=now,split(lc[now],a,lc[b],val);
upd(now);
}
int find(int x,int k){
while(siz[lc[x]]+1!=k){
if(k<=siz[lc[x]]) x=lc[x];
else k-=siz[lc[x]]+1,x=rc[x];
}
return v[x];
}
int main()
{
scanf("%d",&n);
int op,x,a,b,c;
while(n--){
scanf("%d%d",&op,&x);
if(op==1) split(rt,a,b,x),c=Newnode(x),merge(a,a,c),merge(rt,a,b);
if(op==2) split(rt,a,b,x),split(a,a,c,x-1),merge(c,lc[c],rc[c]),merge(a,a,c),merge(rt,a,b);
if(op==3) split(rt,a,b,x-1),printf("%d\n",siz[a]+1),merge(rt,a,b);
if(op==4) printf("%d\n",find(rt,x));
if(op==5) split(rt,a,b,x-1),printf("%d\n",find(a,siz[a])),merge(rt,a,b);
if(op==6) split(rt,a,b,x),printf("%d\n",find(b,1)),merge(rt,a,b);
}
}