[学习笔记] Splay

为了学 link-cut-tree \text{link-cut-tree} 才讲的 Splay \text{Splay} ,之前已经学过无旋 treap \text{treap} 了,因为本质上都是对二叉搜索树的优化,理解起来可能会更容易吧,下面就以这一道例题:普通平衡树,来讲解一下 Splay \text{Splay} 的基本操作。

数组定义

  • c h [ x ] [ 0 / 1 ] ch[x][0/1] ,表示 x x 的左儿子或者右儿子。
  • v a l [ x ] val[x] ,表示 x x 点的键值。
  • c n t [ x ] cnt[x] ,表示 x x 该点的出现次数。
  • p a r [ x ] par[x] ,表示 x x 的父亲。
  • s i z [ x ] siz[x] ,表示 x x 为根的子树的大小。

具体操作

chk 操作
辅助操作,找 x x 是它父亲的左儿子还是右儿子。

int chk(int x)
{
	return ch[par[x]][1]==x;
}

push_up 操作
辅助操作,用左儿子和右儿子更新一下 s i z siz 数组。

void push_up(int x)
{
	siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
}

rotate 操作
Splay \text{Splay} 的核心操作,旋转 x x ,先看一个例子:
在这里插入图片描述
其中,一种较为优秀的转法是这样的:
在这里插入图片描述
多模拟几次,我们总结一下,假设它的父亲是 y y y y 的父亲是 z z ,我们先找出 x x y y 的 左儿子 / / 右儿子,记它为 k k ,我们把 c h [ y ] [ k ] ch[y][k] 替换成 c h [ x ] [ ! k ] ch[x][!k] ,把 c h [ z ] [ c h k ( y ) ] ch[z][chk(y)] 替换成 x x c h [ x ] [ ! k ] ch[x][!k] 替换成 y y ,然后再更新 x x y y ,代码如下:

void rotate(int x)
{
	int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
	ch[y][k]=w;par[w]=y;
	ch[z][chk(y)]=x;par[x]=z;
	ch[x][k^1]=y;par[y]=x;
	push_up(y);push_up(x);
}

Splay 操作
核心操作,把点 x x 旋到 y y 的子节点处,这里我们使用双选,如果 x , y , z x,y,z 三点共线,那么我们先旋转 y y ,再旋转 x x ;否则我们旋转两次 x x ,这样旋转出来的树形态更优,代码如下:

void splay(int x,int goal=0)
{
	while(par[x]!=goal)
	{
		int y=par[x],z=par[y];
		if(z!=goal)
		{
			if(chk(x)==chk(y)) rotate(y);
			else rotate(x);
		}
		rotate(x);
	}
	if(!goal) rt=x;
}

find 操作
把第一个值小于等于 x x 点旋转到根,我们先用二叉查找树的方法找到它,然后直接 Splay \text{Splay} 它到根。

void find(int x)
{
	if(!rt) return ;
	int cur=rt;
	while(ch[cur][x>val[cur]] && x!=val[cur])
		cur=ch[cur][x>val[cur]];
	splay(cur);
}

insert 操作
插入 x x 这个值,我们先查找这个值,如果找到了,把次数 + 1 +1 ,否则我们新建一个节点,然后把这个节点旋转到根(随机化树形态)。

void insert(int x)
{
	int cur=rt,p=0;
	while(cur && val[cur]!=x)
	{
		p=cur;
		cur=ch[cur][x>val[cur]];
	}
	if(cur) cnt[cur]++;
	else
	{
		cur=++ncnt;
		if(p) ch[p][x>val[p]]=cur;
		par[cur]=p;ch[cur][0]=ch[cur][1]=0;
		cnt[cur]=siz[cur]=1;val[cur]=x;
	}
	splay(cur);
}

kth 操作
找到第 k k 大的值,用普通平衡树的方法,先找左子树的点数够不够,否则看左子树+当前点数够不够,足够则第 k k 大就是当前点,否则去找右子树,具体实现如下:

int kth(int k)
{
	int cur=rt;
	while(1)
	{
		if(ch[cur][0] && k<=siz[ch[cur][0]])
			cur=ch[cur][0];
		else if(k>siz[ch[cur][0]]+cnt[cur])
		{
			k-=siz[ch[cur][0]]+cnt[cur];
			cur=ch[cur][1];
		}
		else return cur;
	}
}

pre / / suc 操作
f i n d ( x ) find(x) 把小于等于 x x 的值旋转到根,如果根是 x x ,那么找 左子树最底下的右儿子 / / 右子树最底下的左儿子,否则答案就是根。

int pre(int x)
{
	find(x);
	if(val[rt]<x) return rt;
	int cur=ch[rt][0];
	while(ch[cur][1]) cur=ch[cur][1];
	return cur;
}
int suc(int x)
{
	find(x);
	if(val[rt]>x) return rt;
	int cur=ch[rt][1];
	while(ch[cur][0]) cur=ch[cur][0];
	return cur;
}

remove 操作
删除 x x ,找到 x x 的前驱和后继,把前驱旋转到根,后继旋转到前驱,所以 x x 一定是后继的左儿子,且 x x 的子树为空,所以可以直接删除 x x ,具体实现如下:

void remove(int x)
{
	int lst=pre(x),nxt=suc(x);
	splay(lst);splay(nxt,lst);
	int now=ch[nxt][0];
	if(cnt[now]>1)
	{
		cnt[now]--;
		splay(now);
	}
	else ch[nxt][0]=0;
}

至此 Splay \text{Splay} 的基本操作就讲完了,下面贴个完整代码吧,更多操作还是慢慢学吧。

#include <cstdio>
const int M = 200005;
int read()
{
	int x=0,flag=1;char c;
	while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
	while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
	return x*flag;
}
int n,rt,ncnt,ch[M][2],val[M],cnt[M],par[M],siz[M];
int chk(int x)
{
	return ch[par[x]][1]==x;
}
void push_up(int x)
{
	siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
}
void rotate(int x)
{
	int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
	ch[y][k]=w;par[w]=y;
	ch[z][chk(y)]=x;par[x]=z;
	ch[x][k^1]=y;par[y]=x;
	push_up(y);push_up(x);
}
void splay(int x,int goal=0)
{
	while(par[x]!=goal)
	{
		int y=par[x],z=par[y];
		if(z!=goal)
		{
			if(chk(x)==chk(y)) rotate(y);
			else rotate(x);
		}
		rotate(x);
	}
	if(!goal) rt=x;
}
void find(int x)
{
	if(!rt) return ;
	int cur=rt;
	while(ch[cur][x>val[cur]] && x!=val[cur])
		cur=ch[cur][x>val[cur]];
	splay(cur);
}
void insert(int x)
{
	int cur=rt,p=0;
	while(cur && val[cur]!=x)
	{
		p=cur;
		cur=ch[cur][x>val[cur]];
	}
	if(cur) cnt[cur]++;
	else
	{
		cur=++ncnt;
		if(p) ch[p][x>val[p]]=cur;
		par[cur]=p;ch[cur][0]=ch[cur][1]=0;
		cnt[cur]=siz[cur]=1;val[cur]=x;
	}
	splay(cur);
}
int kth(int k)
{
	int cur=rt;
	while(1)
	{
		if(ch[cur][0] && k<=siz[ch[cur][0]])
			cur=ch[cur][0];
		else if(k>siz[ch[cur][0]]+cnt[cur])
		{
			k-=siz[ch[cur][0]]+cnt[cur];
			cur=ch[cur][1];
		}
		else return cur;
	}
}
int pre(int x)
{
	find(x);
	if(val[rt]<x) return rt;
	int cur=ch[rt][0];
	while(ch[cur][1]) cur=ch[cur][1];
	return cur;
}
int suc(int x)
{
	find(x);
	if(val[rt]>x) return rt;
	int cur=ch[rt][1];
	while(ch[cur][0]) cur=ch[cur][0];
	return cur;
}
void remove(int x)
{
	int lst=pre(x),nxt=suc(x);
	splay(lst);splay(nxt,lst);
	int now=ch[nxt][0];
	if(cnt[now]>1)
	{
		cnt[now]--;
		splay(now);
	}
	else ch[nxt][0]=0;
}
int main()
{
	n=read();
	insert(0x3f3f3f3f);
	insert(0xcfcfcfcf);
	for(int i=1;i<=n;i++)
	{
		int op=read(),x=read();
		if(op==1) insert(x);
		if(op==2) remove(x);
		if(op==3)
		{
			find(x);
			printf("%d\n",siz[ch[rt][0]]);
		}
		if(op==4) printf("%d\n",val[kth(x+1)]);
		if(op==5) printf("%d\n",val[pre(x)]);
		if(op==6) printf("%d\n",val[suc(x)]);
	}
}

例题

第一道题:序列终结者 splay \text{splay} 打标记入门题
第二道题:SuperMemo,这道题相对于上一道题多了一个 Revolve \text{Revolve} 操作,直接把 [ l , r k ] [l,r-k] 这个区间拆出来,把它当作点,直接插回原序列,下面贴上我的代码:

#include <cstdio>
#include <iostream>
using namespace std;
#define inf 0x3f3f3f3f
const int M = 200005;
int read()
{
	int x=0,flag=1;char c;
	while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
	while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
	return x*flag;
}
int n,m,rt,ncnt,ch[M][2],val[M],Min[M],par[M],siz[M],fl[M],la[M];
char s[10];
int chk(int x)
{
	return ch[par[x]][1]==x;
}
void push_up(int x)//上传 
{
	if(!x) return ;
	siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1;
	Min[x]=min(min(Min[ch[x][0]],Min[ch[x][1]]),val[x]);
}
void flip(int x)//翻转 
{
	if(!x) return ;
	swap(ch[x][0],ch[x][1]);
	fl[x]^=1;
}
void add(int x,int c)//加法 
{
	if(!x) return ;
	Min[x]+=c;val[x]+=c;
	la[x]+=c;
}
void push_down(int x)//下传标记 
{
	if(fl[x])
	{
		flip(ch[x][0]);flip(ch[x][1]);
		fl[x]=0;
	}
	if(la[x])
	{
		add(ch[x][0],la[x]);add(ch[x][1],la[x]);
		la[x]=0;
	}
}
void rotate(int x)//旋转 
{
	int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
	push_down(y);push_down(x);
	ch[y][k]=w;par[w]=y;
	ch[z][chk(y)]=x;par[x]=z;
	ch[x][k^1]=y;par[y]=x;
	push_up(y);push_up(x);
}
void splay(int x,int goal=0)//把x旋转到goal 
{
	while(par[x]^goal)
	{
		int y=par[x],z=par[y];
		if(z!=goal)
		{
			if(chk(x)==chk(y)) rotate(y);
			else rotate(x);
		}
		rotate(x);
	}
	if(!goal) rt=x; 
}
int find(int k)//排名为k的点 
{
	int cur=rt;
	while(1)
	{
		push_down(cur);
		if(ch[cur][0] && k<=siz[ch[cur][0]])
			cur=ch[cur][0];
		else if(k>siz[ch[cur][0]]+1)
		{
			k-=siz[ch[cur][0]]+1;
			cur=ch[cur][1];
		}
		else return cur;
	}
}
void print(int x)
{
	if(!x) return ;
	push_down(x);
	print(ch[x][0]);
	printf("%d ",val[x]);
	print(ch[x][1]);
}
void ins(int x,int k)//把x插入k位后 
{
	int a=find(k),b=find(k+1);
	splay(a);splay(b,a);
	ch[b][0]=x;par[x]=b;
	push_up(b);
}
void del(int k)//删除k位
{
	int a=find(k-1),b=find(k+1);
	splay(a);splay(b,a);
	ch[b][0]=0;
	push_up(b);
}
int main()
{
	n=read();
	Min[0]=inf;ncnt=2;
	rt=1;siz[1]=siz[2]=1;
	ch[1][1]=2;par[2]=1;//加入哨兵 
	for(int i=1;i<=n;i++)
	{
		siz[++ncnt]=1;val[ncnt]=Min[ncnt]=read();
		ins(ncnt,ncnt-2);
	}
	m=read();
	while(m--)
	{
		scanf("%s",s);
		if(s[0]=='D')
		{
			del(read()+1);//删除(要考虑哨兵) 
			continue ;
		}
		int l=read(),r=read();
		if(s[0]=='A')//区间加 
		{
			int a=find(l),b=find(r+2);
			splay(a);splay(b,a);
			add(ch[b][0],read());
		}
		if(s[0]=='R' && s[3]=='E')//翻转 
		{
			int a=find(l),b=find(r+2);
			splay(a);splay(b,a);
			flip(ch[b][0]);
		}
		if(s[0]=='R' && s[3]=='O')
		{
			int k=read();
			k=(k%(r-l+1)+(r-l+1))%(r-l+1);
			//[l,r-k]
			int a=find(l),b=find(r-k+2);//拆区间 
			splay(a);splay(b,a);
			int t=ch[b][0];
			ch[b][0]=0;par[t]=0;
			ins(t,l+k);//重新插入 
		}
		if(s[0]=='I')//插入 
		{
			siz[++ncnt]=1;
			val[ncnt]=Min[ncnt]=r;
			ins(ncnt,l+1);
		}
		if(s[0]=='M')//查询最小值 
		{
			int a=find(l),b=find(r+2);
			splay(a);splay(b,a);
			printf("%d\n",Min[ch[b][0]]);
		}
	}
}
发布了192 篇原创文章 · 获赞 12 · 访问量 3337

猜你喜欢

转载自blog.csdn.net/C202044zxy/article/details/103893878