讲解博客:https://www.luogu.com.cn/blog/user19027/solution-p3369
模板博客:https://blog.csdn.net/clove_unique/article/details/50630280
补充一句:找前驱和后驱时pre和next函数只能找严格大于或小于的值,如果要找大于等于或小于等于的值,可以根据cnt另写一个函数,并且在pre和next之前需要insert一下,将需要查询的点旋转到根节点上
代码:
class Splay
{
public:
int ch[N][2],f[N],size[N],cnt[N],key[N];//size:子树大小,cnt:当前节点出现次数,key:权值
int sz,root;
inline void clear(int x)
{
ch[x][0]=ch[x][1]=f[x]=size[x]=cnt[x]=key[x]=0;
}
inline bool get(int x)
{
return ch[f[x]][1]==x;
}
inline void update(int x)
{
if(x)
{
size[x]=cnt[x];
if(ch[x][0])
size[x]+=size[ch[x][0]];
if(ch[x][1])
size[x]+=size[ch[x][1]];
}
}
inline void rotate(int x)
{
int old=f[x],oldf=f[old],whichx=get(x);
ch[old][whichx]=ch[x][whichx^1];
f[ch[old][whichx]]=old;
ch[x][whichx^1]=old;
f[old]=x;
f[x]=oldf;
if(oldf)
ch[oldf][ch[oldf][1]==old]=x;
update(old);
update(x);
}
inline void splay(int x)
{
for(int fa;fa=f[x];rotate(x))
if(f[fa])
rotate((get(x)==get(fa))?fa:x);
root=x;
}
inline void insert(int x)
{
if(root==0)
{
sz++;
ch[sz][0]=ch[sz][1]=f[sz]=0;
root=sz;
size[sz]=cnt[sz]=1;
key[sz]=x;
return;
}
int now=root,fa=0;
while(1)
{
if(x==key[now])
{
cnt[now]++;
update(now);
update(fa);
splay(now);
break;
}
fa=now;
now=ch[now][key[now]<x];
if(now==0)
{
sz++;
ch[sz][0]=ch[sz][1]=0;
f[sz]=fa;
size[sz]=cnt[sz]=1;
ch[fa][key[fa]<x]=sz;
key[sz]=x;
update(fa);
splay(sz);
break;
}
}
}
inline int find(int x)//查询x的排名
{
int now=root,ans=0;
while(1)
{
if(x<key[now])
now=ch[now][0];
else
{
ans+=(ch[now][0]?size[ch[now][0]]:0);
if(x==key[now])
{
splay(now);
return ans+1;
}
ans+=cnt[now];
now=ch[now][1];
}
}
}
inline int findx(int x)//找到排名为x的点
{
int now=root;
while(1)
{
if(ch[now][0]&&x<=size[ch[now][0]])
now=ch[now][0];
else
{
int temp=(ch[now][0]?size[ch[now][0]]:0)+cnt[now];
if(x<=temp)
return key[now];
x-=temp;
now=ch[now][1];
}
}
}
inline int pre()//小于某个数的最大值
{
int now=ch[root][0];
while(ch[now][1])
now=ch[now][1];
return now;
}
inline int next()//大于某个数的最小值
{
int now=ch[root][1];
while(ch[now][0])
now=ch[now][0];
return now;
}
inline void del(int x)
{
int whatever=find(x);
if(cnt[root]>1)
{
cnt[root]--;
update(root);
return;
}
if(!ch[root][0]&&!ch[root][1])
{
clear(root);
root=0;
return;
}
if(!ch[root][0])
{
int oldroot=root;
root=ch[root][1];
f[root]=0;
clear(oldroot);
return;
}
else if(!ch[root][1])
{
int oldroot=root;
root=ch[root][0];
f[root]=0;
clear(oldroot);
return;
}
int leftbig=pre(),oldroot=root;
splay(leftbig);
ch[root][1]=ch[oldroot][1];
f[ch[oldroot][1]]=root;
clear(oldroot);
update(root);
}
}tree;