Tyvj P1728 普通平衡树 (平衡树)

题目链接

P1728 普通平衡树
时间: 1000ms / 空间: 131072KiB / Java类名: Main

背景

此为平衡树系列第一道:普通平衡树

描述

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)

输入格式

第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)

输出格式

对于操作3,4,5,6每行输出一个数,表示对应答案

测试样例1

输入

8
1 10
1 20
1 30
3 20
4 2
2 10
5 25
6 -1

输出

2
20
20
20

备注

n<=100000 所有数字均在-10^7到10^7内

题解:平衡树的基础操作。

Treap代码:

#include<algorithm>
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<math.h>
#include<stdlib.h>
#define nn 210000
#define eps 1e-8
#define inff 0x7fffffff
#define lson rt<<1,l,m
#define rson rt<<1|1,m+1,r
#define mod 20071027
using namespace std;
typedef long long LL;
typedef unsigned long long LLU;
struct multi_treap
{
    struct node
    {
        node *ch[2];
        int r,v;
        int cnt,sum;
        int cmp(int x)
        {
            if(v==x)
                return -1;
            return x<v?0:1;
        }
    };
    node *root;//treap的根节点
    int numof(node* o)//以o为根的树的结点数目
    {
        if(o==NULL)
            return 0;
        return o->sum;
    }
    void update(node* o)//更新以该节点为根的数的个数
    {
        if(o==NULL)
            return ;
        o->sum=numof(o->ch[0])+numof(o->ch[1])+o->cnt;
    }
    void Rotate(node* &o,int d)//d为0左旋,d为1右旋
    {
        node *k=o->ch[d^1];
        o->ch[d^1]=k->ch[d];
        k->ch[d]=o;
        update(o);
        update(k);
        o=k;
    }
    void Insert(node* &o,int x)//插入x
    {
        if(o==NULL)
        {
            o=new node();
            o->v=x,o->r=rand()*rand();
            o->cnt=o->sum=1;
            o->ch[0]=o->ch[1]=NULL;
            return ;
        }
        int d=o->cmp(x);
        if(d==-1)
        {
            o->cnt++;//注释掉这两句话就是关闭元素可重复的功能
            o->sum++;
            return ;
        }
        Insert(o->ch[d],x);
        if(o->ch[d]->r>o->r)
        {
            Rotate(o,d^1);
        }
        update(o);
    }
    void Remove(node* &o,int x)//删除x
    {
        if(o==NULL)
            return ;
        int d=o->cmp(x);
        if(d==-1)
        {
            if(o->cnt==1)
            {
                if(o->ch[0]==NULL)
                {
                    node *k=o;
                    o=o->ch[1];
                    delete k;
                }
                else if(o->ch[1]==NULL)
                {
                    node *k=o;
                    o=o->ch[0];
                    delete k;
                }
                else
                {
                    int d2=o->ch[0]->r>o->ch[1]->r?1:0;
                    Rotate(o,d2);
                    Remove(o->ch[d2],x);
                }
            }
            else
                o->cnt--;
        }
        else
            Remove(o->ch[d],x);
        update(o);
    }
    int Rank(node *o,int x)//查询x的排名
    {
        int ans=0;
        while(o!=NULL)
        {
            if(o->v>x)
            {
                o=o->ch[0];
            }
            else if(o->v<x)
            {
                ans+=numof(o->ch[0])+o->cnt;
                o=o->ch[1];
            }
            else
            {
                ans+=numof(o->ch[0])+1;
                break;
            }
        }
        return ans;
    }
    int Kth(node *o,int x)//查询排名为x的数
    {
        if(o==NULL)
            return -1;
        int ix=numof(o->ch[0]);
        if(ix>=x)
            return Kth(o->ch[0],x);
        ix=x-ix-o->cnt;
        if(ix<=0)
            return o->v;
        return Kth(o->ch[1],ix);
    }
    int Pre(node* o,int x)//查询x的前驱
    {
        int ans=-1;
        while(o!=NULL)
        {
            if(o->v<x)
            {
                ans=o->v;
                o=o->ch[1];
            }
            else
                o=o->ch[0];
        }
        return ans;
    }
    int Suc(node* o,int x)//查询x的后继
    {
        int ans=-1;
        while(o!=NULL)
        {
            if(o->v>x)
            {
                ans=o->v;
                o=o->ch[0];
            }
            else
                o=o->ch[1];
        }
        return ans;
    }
    void De(node* o)//清空
    {
        if(o==NULL)
            return ;
        De(o->ch[0]);
        De(o->ch[1]);
        delete o;
    }
}tp;
int main()
{
    int n;
    int x,y,i;
    while(scanf("%d",&n)!=EOF)
    {
        tp.root=NULL;
        for(i=1;i<=n;i++)
        {
            scanf("%d%d",&x,&y);
            if(x==1)
            {
                tp.Insert(tp.root,y);
            }
            else if(x==2)
            {
                tp.Remove(tp.root,y);
            }
            else if(x==3)
            {
                printf("%d\n",tp.Rank(tp.root,y));
            }
            else if(x==4)
            {
                printf("%d\n",tp.Kth(tp.root,y));
            }
            else if(x==5)
            {
                printf("%d\n",tp.Pre(tp.root,y));
            }
            else
                printf("%d\n",tp.Suc(tp.root,y));
        }
        tp.De(tp.root);
    }
    return 0;
}

Splay

#include<stdio.h>
#include<iostream>
#include<algorithm>
const int inf=0x3fffffff;
using namespace std;
struct node
{
    int val;
    int num;
    int sum;
    node* pre;
    node* ch[2];
}*root;
void display(node* o)
{
    if(o==NULL)
        return ;
    display(o->ch[0]);
    cout<<o->val<<endl;
    display(o->ch[1]);
}
int numof(node* o)
{
    if(o==NULL)
        return 0;
    return o->sum;
}
void update(node* o)
{
    if(o==NULL)
        return ;
    o->sum=o->num+numof(o->ch[0])+numof(o->ch[1]);
}
void Rotate(node* o)
{
    node* tem=o->pre;
    int d;
    if(tem->ch[0]==o) d=0;
    else d=1;
    tem->ch[d]=o->ch[d^1];
    if(o->ch[d^1]!=NULL)
        o->ch[d^1]->pre=tem;
    if(tem->pre!=NULL)
    {
        if(tem->pre->ch[0]==tem)
            tem->pre->ch[0]=o;
        else
            tem->pre->ch[1]=o;
    }
    o->pre=tem->pre;
    o->ch[d^1]=tem;
    tem->pre=o;
    update(tem);
    //update(o);伸展完以后再更新,减少常数
}
void Splay(node* o,node* f)
{
    node* x;
    node* y;
    while(o->pre!=f)
    {
        if(o->pre->pre==f)
            Rotate(o);
        else
        {
            x=o->pre;
            y=x->pre;
            int d1,d2;
            if(y->ch[0]==x) d1=0;
            else d1=1;
            if(x->ch[0]==o) d2=0;
            else d2=1;
            if(d1==d2)
            {
                Rotate(x);
                Rotate(o);
            }
            else
            {
                Rotate(o);
                Rotate(o);
            }
        }
    }
    update(o);
    if(f==NULL)
        root=o;
}
void Insert(node* &o,node* pre,int val)
{
    if(o==NULL)
    {
        o=new node;
        o->val=val;
        o->sum=o->num=1;
        o->pre=pre;
        o->ch[0]=o->ch[1]=NULL;
        Splay(o,NULL);
        return ;
    }
    if(val==o->val)
    {
        o->num++;
        o->sum++;
        Splay(o,NULL);//一定要先更新,在伸展
    }
    else if(val<o->val)
        Insert(o->ch[0],o,val);
    else
        Insert(o->ch[1],o,val);
}
node* Pre(node* o,int val)
{
    node* re;
    while(o!=NULL)
    {
        if(o->val<val)
        {
            re=o;
            o=o->ch[1];
        }
        else
            o=o->ch[0];
    }
    return re;
}
node* Suc(node* o,int val)
{
    node* re;
    while(o!=NULL)
    {
        if(o->val>val)
        {
            re=o;
            o=o->ch[0];
        }
        else
            o=o->ch[1];
    }
    return re;
}
void Remove(node* o,int val)
{
    if(o==NULL)
        return ;
    if(val==o->val)
    {
        node* x=Pre(root,val);
        node* y=Suc(root,val);
        Splay(x,NULL);
        Splay(y,x);
        o->num--;
        o->sum--;
        if(o->num==0)
        {
            delete y->ch[0];
            y->ch[0]=NULL;
        }
        Splay(y,NULL);
    }
    else if(val<o->val)
        Remove(o->ch[0],val);
    else
        Remove(o->ch[1],val);
}
int Rank(node* o,int val)
{
    if(o==NULL)
        return 0;
    if(val==o->val)
    {
        return numof(o->ch[0]);
    }
    else if(val<o->val)
    {
        return Rank(o->ch[0],val);
    }
    else
        return numof(o->ch[0])+o->num+Rank(o->ch[1],val);
}
int Kth(node* o,int k)
{
//    if(o==NULL)
//        return 0;
    if(numof(o->ch[0])>=k)
        return Kth(o->ch[0],k);
    else if(numof(o->ch[0])+o->num>=k)
        return o->val;
    else
        return Kth(o->ch[1],k-numof(o->ch[0])-o->num);
}
void Clear(node* o)
{
    if(o==NULL)
        return ;
    Clear(o->ch[0]);
    Clear(o->ch[1]);
    delete o;
}
int main()
{
    int i,d,x;
    int n;
    while(scanf("%d",&n)!=EOF)
    {
        root=NULL;
        Insert(root,NULL,-inf);
        Insert(root,NULL,inf);
        for(i=1;i<=n;i++)
        {
            scanf("%d%d",&d,&x);
            if(d==1)
            {
                Insert(root,NULL,x);
            }
            else if(d==2)
            {
                Remove(root,x);
            }
            else if(d==3)
            {
                printf("%d\n",Rank(root,x));
            }
            else if(d==4)
            {
                printf("%d\n",Kth(root,x+1));
            }
            else if(d==5)
            {
                printf("%d\n",Pre(root,x)->val);
            }
            else
            {
                printf("%d\n",Suc(root,x)->val);
            }
        }
        Clear(root);
    }
    return 0;
}


猜你喜欢

转载自blog.csdn.net/madaidao/article/details/46049317
今日推荐