这个名字取得比较玄乎,一眼看上去并不知道有什么卵用,但是, 如果你是刚学平衡树的新手,那么从替罪羊树开始学一定是个绝佳的选择,因为它是个很优雅的平衡树,什么叫优雅?暴力即是优雅!
如果在一棵平衡的二叉搜索树内进行查询等操作,时间就可以稳定在log(n),但是每一次的插入节点和删除节点,都可能会使得这棵树不平衡,最坏情况就是退化成一条链,显然我们不想要这种树,于是各种维护的方法出现了,大部分的平衡树都是通过旋转来维护平衡的,但替罪羊树就很厉害了,一旦发现不平衡的子树,立马拍扁重建,这就是替罪羊树的核心:暴力重建
先来说说我的替罪羊树上的每个节点包含些什么:
- zuo,you:记录该节点的左右儿子
- x:该节点的值
- tot:有多少个值为x的数
- size,trsize,whsize:size表示以该节点为根的子树内有多少个节点,trsize表示有多少个有效节点(这个后面再讲啦),whsize表示有多少个数(也就是子树内所有节点的tot的和)
- fa:该点的父亲
- tf:该点是否有删除标记(这个也会在后面讲啦)
那以洛谷上的模板题【模板】普通平衡树来作为例子吧!
我们分开讨论每一种操作:
操作1——加点:
先找到一个特殊的节点,如果那个节点的值等于要加的那个点,那么直接让那个节点的tot+1即可,否则如果比那个节点的值要小,就让新加的节点成为它的左儿子,不然就是右儿子。
那么怎么找那个“特殊的节点”呢?假如我以x为关键字去查找,先从根节点开始,假如x比根节点的值要小,那我就去它的左儿子那里,否则去右儿子,直到满足这两个条件中的一个:找到了值为x的节点或不能继续往下走。
那为什么要找这个特殊的节点呢?因为如果这个特殊的节点的值等于要加的点,那么直接
即可,否则让新加的点成为这个节点的儿子,并且这一定是最佳的选择,为什么呢?仔细想想,其实这个找特殊点其实就是找值为x的点,没有的话就找到最深的最接近的点。
那找点和加点的代码如下:
找点:
int find(int x,int now)//now表示当前找到哪个点
{
if(x<tree[now].x&&tree[now].zuo)return find(x,tree[now].zuo);//比当前点的值要小并且有左儿子
if(x>tree[now].x&&tree[now].you)return find(x,tree[now].you);
return now;
}
加点:
void add(int x)
{
if(root==0)//假如当前没有根节点,也就是当前的树是空的,那么直接让他成为根
{
build(x,root=kk(),0);//新建节点(后面有讲)
return;
}
int p=find(x,root);//找到特殊点
if(x==tree[p].x)
{
tree[p].tot++;
if(tree[p].tf)tree[p].tf=false,updata(p,1,0,1);
else updata(p,0,0,1);
}
else if(x<tree[p].x)build(x,tree[p].zuo=kk(),p),updata(p,1,1,1);
else build(x,tree[p].you=kk(),p),updata(p,1,1,1);
find_rebuild(root,x);
}//没讲到的先别管啦
然后再加上里面用到的几个函数:
新建节点:
void build(int x,int y,int fa)//初始化树上编号为y的节点,它的值为x,父亲为fa
{
tree[y].zuo=tree[y].you=0;tree[y].fa=fa;tree[y].tf=false;
tree[y].x=x;tree[y].tot=tree[y].size=tree[y].trsize=tree[y].whsize=1;
}
函数,更新父亲以及爷爷以及祖先们的 , 还有 update 啦):
void updata(int x,int y,int z,int k)
{
if(!x)return;//假如到头了就停止
tree[x].trsize+=y;
tree[x].size +=z;//对齐qwq
tree[x].whsize+=k;
updata(tree[x].fa,y,z,k);
}
然后里面的kk()函数就先留到后面,现在讲会牵扯到好多东西的(虽然代码极短),就先记住它的用处吧:新建一个节点,假如之前有废掉的节点那么就直接用,否则加多一个节点。
操作2——删点:
删点,严格来说是删掉一个数,假如我要删一个值为x的数,那就先找到值为x的节点,然后 。
没啦?当然不是,假如 之后 变成 了怎么办?这意味着这个节点不存在了,然后我们删掉它?假如把它删了,那它的左右儿子何去何从?所以我们不能动它,给它打个标记,标记这货被删除了,然后就行了。
代码在此:
void del(int x)
{
int p=find(x,root);
tree[p].tot--;
if(!tree[p].tot)tree[p].tf=true,updata(p,-1,0,-1);
else updata(p,0,0,-1);
find_rebuild(root,x);
}
各位肯定敏锐的发现了这两个函数里都用到了一个函数,find_rebuild,回顾上面的内容,我提到过,每一次的加点和删点都有可能使这棵树不平衡,假如有一棵子树不平衡,我们就需要将其重建,所以,find_rebuild就是用来查找需要重建的子树。
先说一下怎么重建吧。
因为需要重建的子树必定是二叉搜索树,那么这棵子树的中序遍历一定是一个严格上升的序列,于是我们就先中序遍历一下,把树上的有效节点放到一个数组里面,注意无效节点(无效节点也就是被打了删除标记的点)不要,毕竟它名存实亡。
然后我们再把数组中的节点重建成一棵及其平衡的完全二叉树(按完全二叉树的方法来建,但因为节点数的原因,不一定是一棵完全二叉树),具体方法就是每一次选取数组中间的节点,让它成为根,左边的当他左儿子,右边的当他右儿子,因为左边的都比他小,右边的都比他大,所以建出来的依然是一棵二叉搜索树。然后再对它的左右儿子进行相同操作即可。
然后我们再讲一讲怎么找需要重建的子树。
我们设每一次
或
的数为
,在将这个数加入到树中或从树中删除之后,假如在树中值为
的节点是
,那我们考虑到其实每一次可能需要重构的子树只会是以 根到
路径上的节点 为根的子树,那么我们就可以从根往
走一次,看看谁要重建就好了。还有个问题,为啥不从
往根走呢?打个比方,假如根到
路径上有两个点,
和
,并且
的祖先节点,然后特别巧的发现
都是需要重建的,那么,这时候我们只需要重建以
为根的子树,因为重建完之后,以b为根的子树其实也重建完了,但要是从
往根走呢?那么先会重建b,然后到a的时候还是要重建,那显然没有直接重建
要好,所以要从根往y走。
最后一个小重点,怎么判断一棵替罪羊树是否平衡呢?(判断的方法不唯一,只要保持平衡即可,这里只是给出本人的做法)
在替罪羊树中,定义了一个平衡因子α,α的范围因题而异,一般取0.5~1.0之间,若题目没有特殊说明,一般就取个中0.75就好了。那这个α有啥用呢?替罪羊树判断一棵子树是否平衡的方法是:如果 x的左(右)子树的节点数量 > 以x为根的子树的节点数量α ,那么,以x为根的这棵子树就是不平衡的。显然的,如果有一棵子树的大小超过了 以x为根的子树的节点数量α,那么这种节点一边倒的情况对于查询来说肯定就很慢,所以,这个时候我们就将它重建。
还有一种情况,我们提到过,替罪羊树的删除只是打个标记,那么,我们在查询的时候还是有可能经过打了删除标记的节点的,假如有删除标记的节点多了,那效率自然就会变得特别低,所以,我们需要再判断一下,假如在一棵子树中,有超过30%的点被删除了,那么就把这棵树重建。
find_rebuild代码如下:
void find_rebuild(int now,int x)//now表示现在走到的节点,x表示要一直走到值为x的点
{
if((double)tree[tree[now].zuo].size>(double)tree[now].size*alpha||
(double)tree[tree[now].you].size>(double)tree[now].size*alpha||
(double)tree[now].size-(double)tree[now].trsize>(double)tree[now].size*0.3){rebuild(now);return;}
if(tree[now].x!=x)find_rebuild(x<tree[now].x?tree[now].zuo:tree[now].you,x);//继续向下搜索
}
rebuild代码如下:
void rebuild(int x)//重建以x为根的子树
{
tt=0;//数组下标
dfs_rebuild(x);//进行中序遍历并将有效节点压入数组
if(x==root)root=readd(1,tt,0);//x就是根,那么root就变成重建之后的那棵树的根
//readd用来把数组里的节点重新建成一棵完全二叉树,并返回这棵树的根
else
{
updata(tree[x].fa,-(tree[x].size-tree[x].trsize),0,0);//更新一下,因为被打了删除标记的节点即将不复存在,所以告诉x的祖先们要去掉原先被打了删除标记的节点
if(tree[tree[x].fa].zuo==x)tree[tree[x].fa].zuo=readd(1,tt,tree[x].fa);
else tree[tree[x].fa].you=readd(1,tt,tree[x].fa);
}
}
readd代码如下:
int readd(int l,int r,int fa)
{
if(l>r)return 0;//没有点了
int mid=(l+r)>>1;//选中间的点作为根
int id=kk();
tree[id].fa=fa;//更新各项
tree[id].tot=shulie[mid].tot;
tree[id].x=shulie[mid].x;
tree[id].zuo=readd(l,mid-1,id);
tree[id].you=readd(mid+1,r,id);
tree[id].whsize=tree[tree[id].zuo].whsize+tree[tree[id].you].whsize+shulie[mid].tot;
tree[id].size=tree[id].trsize=r-l+1;
tree[id].tf=false;
return id;//记得返回
}
还有中序遍历dfs_rebuild的代码:
void dfs_rebuild(int x)
{
if(x==0)return;
dfs_rebuild(tree[x].zuo);//先去左儿子
if(!tree[x].tf)shulie[++tt].x=tree[x].x,shulie[tt].tot=tree[x].tot;//假如没有删除标记,就只将他的x和tot加进数组,因为其他东西都没有用
ck[++t]=x;//仓库,存下废弃的节点
dfs_rebuild(tree[x].you);//再去右儿子
}
最后还有之前用了好多次的kk函数:
int kk()//短的可怕有木有
{
if(t>0)return ck[t--];//假如仓库内有点,就直接用
else return ++len;//否则再创造一个点
}
然后……就剩下几个基本操作啦!
操作3——查找x的排名:
我们只需要像find函数一样走一遍就好了,在走的时候,如果是往右儿子走,就让ans加上左子树的数的个数,再加上当前节点的tot,因为x一定比他们都大,否则就往左儿子走,当走到值为x的点时就结束。
代码也肯定很简单的啦!
void findxpm(int x)//这么简单应该也不用什么注释了吧(其实就是比较懒)
{
int now=root;
int ans=0;
while(tree[now].x!=x)
{
if(x<tree[now].x)now=tree[now].zuo;
else ans+=tree[tree[now].zuo].whsize+tree[now].tot,now=tree[now].you;
}
ans+=tree[tree[now].zuo].whsize;
printf("%d\n",ans+1);
}
操作4——查找排名为x的数:
类似的,先从根走起,假如当前节点的左子树的数的个数比x要小,那么让x减掉左子树的数的个数,然后在看一下当前节点的tot是否大于x,是的话答案就是这个节点了,否则让x减去它的tot,然后往右儿子那里跑,重复以上操作即可。
代码依然是那么简单。
void findpmx(int x)
{
int now=root;
while(1)
{
if(x<=tree[tree[now].zuo].whsize)now=tree[now].zuo;
else
{
x-=tree[tree[now].zuo].whsize;
if(x<=tree[now].tot)
{
printf("%d\n",tree[now].x);
return;
}
x-=tree[now].tot;
now=tree[now].you;
}
}
}
要注意!这两个函数里用的都是whsize!
操作5——查找x的前驱
显然的,x的左儿子的右儿子的右儿子的右儿子……(此处省略无数个右儿子)就是x的前驱,也就是比x小的数里面最大的数,就是x的前驱,但是,如果x没有左儿子怎么办?我们只能去找他的父亲,假如x是它父亲的右儿子,他的父亲就是他的前驱,但是,还有个问题,假如他父亲有删除标记,那么就不能当x的前驱,于是我们又可以找x的父亲左儿子的右儿子的右儿子……,假如他父亲也没有左儿子,那么继续往上走,重复以上操作;另一种情况,x是它父亲的左儿子,那么直接继续往上走就好了。
代码在此:
int pre(int now,int x,bool zy)//zy表示now是从左儿子来的还是右儿子来的,如果是右儿子,zy=true,否则为false,now表示当前节点,x表示我要找x的前驱
{
if(!zy)return pre(tree[now].fa,x,tree[tree[now].fa].zuo!=now);//假如从左儿子来
if(!tree[now].tf&&tree[now].x<x)return tree[now].x;//判断当前节点是否是x的前驱
if(tree[now].zuo)//否则往左儿子的右儿子的右儿子……走
{
now=tree[now].zuo;
while(tree[now].you)now=tree[now].you;
return tree[now].x;
}
return pre(tree[now].fa,x,tree[tree[now].fa].zuo!=now);//假如没有左儿子
}
操作6——查找x的后继
类似操作5,代码基本一样。。
int nxt(int now,int x,bool zy)//注释懒得写了qaq
{
if(!zy)return nxt(tree[now].fa,x,tree[tree[now].fa].you!=now);
if(!tree[now].tf&&tree[now].x>x)return tree[now].x;
if(tree[now].you)
{
now=tree[now].you;
while(tree[now].zuo)now=tree[now].zuo;
return tree[now].x;
}
return nxt(tree[now].fa,x,tree[tree[now].fa].you!=now);
}
终于,替罪羊树学完啦!接下来附上完整代码
#include <cstdio>
#include <cstdlib>
#include <cstring>
struct node{int zuo,you,x,tot,size,trsize,whsize,fa;bool tf;};
node tree[1000010];
int len=0,n,root=0;
int ck[1000010],t=0;
double alpha=0.75;
void build(int x,int y,int fa)
{
tree[y].zuo=tree[y].you=0;tree[y].fa=fa;tree[y].tf=false;
tree[y].x=x;tree[y].tot=tree[y].size=tree[y].trsize=tree[y].whsize=1;
}
inline int kk()
{
if(t>0)return ck[t--];
else return ++len;
}
void updata(int x,int y,int z,int k)
{
if(!x)return;
tree[x].trsize+=y;
tree[x].size +=z;
tree[x].whsize+=k;
updata(tree[x].fa,y,z,k);
}
int find(int x,int now)
{
if(x<tree[now].x&&tree[now].zuo)return find(x,tree[now].zuo);
if(x>tree[now].x&&tree[now].you)return find(x,tree[now].you);
return now;
}
struct sl{int x,tot;}shulie[1000010];
int tt;
void dfs_rebuild(int x)
{
if(x==0)return;
dfs_rebuild(tree[x].zuo);
if(!tree[x].tf)shulie[++tt].x=tree[x].x,shulie[tt].tot=tree[x].tot;
ck[++t]=x;
dfs_rebuild(tree[x].you);
}
int readd(int l,int r,int fa)
{
if(l>r)return 0;
int mid=(l+r)>>1;int id=kk();
tree[id].fa=fa;
tree[id].tot=shulie[mid].tot;
tree[id].x=shulie[mid].x;
tree[id].zuo=readd(l,mid-1,id);
tree[id].you=readd(mid+1,r,id);
tree[id].whsize=tree[tree[id].zuo].whsize+tree[tree[id].you].whsize+shulie[mid].tot;
tree[id].size=tree[id].trsize=r-l+1;
tree[id].tf=false;
return id;
}
void rebuild(int x)
{
tt=0;
dfs_rebuild(x);
if(x==root)root=readd(1,tt,0);
else
{
updata(tree[x].fa,-tree[x].size+tree[x].trsize,0,0);
if(tree[tree[x].fa].zuo==x)tree[tree[x].fa].zuo=readd(1,tt,tree[x].fa);
else tree[tree[x].fa].you=readd(1,tt,tree[x].fa);
}
}
void find_rebuild(int now,int x)
{
if((double)tree[tree[now].zuo].size>(double)tree[now].size*alpha||
(double)tree[tree[now].you].size>(double)tree[now].size*alpha||
(double)tree[now].size-(double)tree[now].trsize>(double)tree[now].size*0.4){rebuild(now);return;}
if(tree[now].x!=x)find_rebuild(x<tree[now].x?tree[now].zuo:tree[now].you,x);
}
void add(int x)
{
if(root==0)
{
build(x,root=kk(),0);
return;
}
int p=find(x,root);
if(x==tree[p].x)
{
tree[p].tot++;
if(tree[p].tf)tree[p].tf=false,updata(p,1,0,1);
else updata(p,0,0,1);
}
else if(x<tree[p].x)build(x,tree[p].zuo=kk(),p),updata(p,1,1,1);
else build(x,tree[p].you=kk(),p),updata(p,1,1,1);
find_rebuild(root,x);
}
void del(int x)
{
int p=find(x,root);
tree[p].tot--;
if(!tree[p].tot)tree[p].tf=true,updata(p,-1,0,-1);
else updata(p,0,0,-1);
find_rebuild(root,x);
}
void findxpm(int x)
{
int now=root;
int ans=0;
while(tree[now].x!=x)
{
if(x<tree[now].x)now=tree[now].zuo;
else ans+=tree[tree[now].zuo].whsize+tree[now].tot,now=tree[now].you;
}
ans+=tree[tree[now].zuo].whsize;
printf("%d\n",ans+1);
}
void findpmx(int x)
{
int now=root;
while(1)
{
if(x<=tree[tree[now].zuo].whsize)now=tree[now].zuo;
else
{
x-=tree[tree[now].zuo].whsize;
if(x<=tree[now].tot)
{
printf("%d\n",tree[now].x);
return;
}
x-=tree[now].tot;
now=tree[now].you;
}
}
}
int pre(int now,int x,bool zy)
{
if(!zy)return pre(tree[now].fa,x,tree[tree[now].fa].zuo!=now);
if(!tree[now].tf&&tree[now].x<x)return tree[now].x;
if(tree[now].zuo)
{
now=tree[now].zuo;
while(tree[now].you)now=tree[now].you;
return tree[now].x;
}
return pre(tree[now].fa,x,tree[tree[now].fa].zuo!=now);
}
int nxt(int now,int x,bool zy)
{
if(!zy)return nxt(tree[now].fa,x,tree[tree[now].fa].you!=now);
if(!tree[now].tf&&tree[now].x>x)return tree[now].x;
if(tree[now].you)
{
now=tree[now].you;
while(tree[now].zuo)now=tree[now].zuo;
return tree[now].x;
}
return nxt(tree[now].fa,x,tree[tree[now].fa].you!=now);
}
int main()
{
scanf("%d",&n);
while(n--)
{
int id,x;
scanf("%d %d",&id,&x);
if(id==1)add(x);
if(id==2)del(x);
if(id==3)findxpm(x);
if(id==4)findpmx(x);
if(id==5)printf("%d\n",pre(find(x,root),x,true));
if(id==6)printf("%d\n",nxt(find(x,root),x,true));
}
}