bzoj3224: Tyvj 1728 普通平衡树/洛谷P3369 【模板】普通平衡树(Treap/SBT)

题目

1.splay

题解

#include<cstdio>
const int N=100003,inf=1<<30;
int n,x,opt;
struct Splay{
    #define rt e[0].ch[1]
    struct node{
        int ch[2],fa,v,sum,same;
    }e[N];
    int n,point;
    void update(int x){
        e[x].sum=e[e[x].ch[0]].sum+e[e[x].ch[1]].sum+e[x].same;
    }
    int id(int x){
        return x==e[e[x].fa].ch[1];
    }
    void con(int x,int y,int son){
        e[x].fa=y;
        e[y].ch[son]=x;
    }
    void rotate(int x){
        int y=e[x].fa,z=e[y].fa,yson=id(x),zson=id(y);
        con(e[x].ch[yson^1],y,yson);
        con(y,x,yson^1);
        con(x,z,zson);
        update(y);
        update(x);
    }
    void splay(int at,int to){
        to=e[to].fa;
        while (e[at].fa!=to){
            int up=e[at].fa;
            if (e[up].fa==to) rotate(at);
            else if (id(up)==id(at)) rotate(up),rotate(at);
            else rotate(at),rotate(at);
        }
    }
    void create(int v,int fa){
        e[++n].v=v;
        e[n].fa=fa;
        e[n].sum=e[n].same=1;
    }
    void destroy(int x){
        e[x].sum=e[x].same=e[x].ch[0]=e[x].ch[1]=e[x].fa=e[x].v=0;
        if (x==n) n--;
    }
    int find(int v){
        int now=rt;
        while (1){
            if (e[now].v==v){
                splay(now,rt);
                return now;
            }
            int ne=v>=e[now].v;
            now=e[now].ch[ne];
            if (!now) return 0;
        }
    }
    int build(int v){
        point++;
        if (n==0){
            rt=1;
            create(v,0);
            return 0;
        }
        int now=rt;
        while (1){
            e[now].sum++;
            if (e[now].v==v){
                e[now].same++;
                return now;
            }
            int ne=v>=e[now].v;
            if (!e[now].ch[ne]){
                create(v,now);
                return e[now].ch[ne]=n;
            }
            now=e[now].ch[ne];
        }
    }
    void push(int v){
        splay(build(v),rt);
    }
    void pop(int v){
        int x=find(v);
        if (!x) return;
        point--;
        if (e[x].same>1){
            e[x].same--;
            e[x].sum--;
            return;
        }
        if (!e[x].ch[0]){
            rt=e[x].ch[1];
            e[rt].fa=0;
        }else{
            int le=e[x].ch[0];
            while (e[le].ch[1]) le=e[le].ch[1];
            splay(le,e[x].ch[0]);
            int ri=e[x].ch[1];
            con(ri,le,1);con(le,0,1);
            update(le);
        }
        destroy(x);
    }
    int rank(int v){
        int now=rt,ans=0;
        while (1){
            if (e[now].v==v) return ans+=e[e[now].ch[0]].sum+1;
            if (v<e[now].v) now=e[now].ch[0];
            else ans+=e[e[now].ch[0]].sum+e[now].same,now=e[now].ch[1];
            if (!now) return 0;
        }
    }
    int arank(int x){
        if (x>point) return -inf;
        int now=rt;
        while (1){
            int ri=e[now].sum-e[e[now].ch[1]].sum;//e[e[now].ch[0]].sum+e[now].same
            if (e[e[now].ch[0]].sum<x && x<=ri) break;
            if (x<=ri) now=e[now].ch[0];
            else x-=ri,now=e[now].ch[1];
        }
        splay(now,rt);
        return e[now].v;
    }
    int lower(int v){
        int now=rt,res=-inf;
        while (now){
            if (e[now].v<v && e[now].v>res) res=e[now].v;
            int ne=v>e[now].v;
            now=e[now].ch[ne];
        }
        return res;
    }
    int upper(int v){
        int now=rt,res=inf;
        while (now){
            if (e[now].v>v && e[now].v<res) res=e[now].v;
            int ne=v>=e[now].v;
            now=e[now].ch[ne];
        }
        return res;
    }
    #undef rt
}S;
inline int read(){
    char c;int x=0,f=1;
    do{c=getchar();if(c=='-')f=-1;}while(c<48||c>57);
    do x=(x<<1)+(x<<3)+(c^48),c=getchar();while(c>=48&&c<=57);
    return f*x;
}
int main(){
    scanf("%d",&n);
    S.push(-inf);
    S.push(inf);
    while (n--){
        opt=read();x=read();
        if (opt==1) S.push(x);
        if (opt==2) S.pop(x);
        if (opt==3) printf("%d\n",S.rank(x)-1);
        if (opt==4) printf("%d\n",S.arank(x+1));
        if (opt==5) printf("%d\n",S.lower(x));
        if (opt==6) printf("%d\n",S.upper(x));
    }
}

2.treap

题解

#include<bits/stdc++.h>
using namespace std;
const int N=100003,inf=1<<30;
int opt,n,x,rt;
struct Treap{
    struct node{
        int ch[2],sum,w,r;
    }e[N];
    int n;
    void update(int x){
        e[x].sum=e[e[x].ch[0]].sum+e[e[x].ch[1]].sum+1;
    }
    void rotate(int &x,int son){
        int t=e[x].ch[son];
        e[x].ch[son]=e[t].ch[son^1];
        e[t].ch[son^1]=x;
        update(x);update(t);
        x=t;
    }
    void push(int v,int &x){
        if (!x){
            x=++n;
            e[x].sum=1;
            e[x].w=v;
            e[x].r=rand();
            return;
        }
        e[x].sum++;
        int ne=v>e[x].w;
        push(v,e[x].ch[ne]);
        if (e[e[x].ch[ne]].r<e[x].r) rotate(x,ne);
    }
    void pop(int v,int &x){
        if (v==e[x].w){
            if (!e[x].ch[0] || !e[x].ch[1]){
                x=e[x].ch[0]+e[x].ch[1];
                return;
            }
            int ne=e[e[x].ch[0]].r>e[e[x].ch[1]].r;
            rotate(x,ne);
            pop(v,e[x].ch[ne^1]);
        }else{
            int ne=v>e[x].w;
            pop(v,e[x].ch[ne]);
        }
        update(x);
    }
    int rank(int v,int x){
        if (!x) return 1;
        if (v<=e[x].w) return rank(v,e[x].ch[0]);
        return rank(v,e[x].ch[1])+e[e[x].ch[0]].sum+1;
    }
    int arank(int v,int x){
        if (e[e[x].ch[0]].sum==v-1) return e[x].w;
        if (v<=e[e[x].ch[0]].sum) return arank(v,e[x].ch[0]);
        return arank(v-e[e[x].ch[0]].sum-1,e[x].ch[1]);
    }
    int lower(int v,int x){
        if (!x) return -inf;
        if (e[x].w<v) return max(e[x].w,lower(v,e[x].ch[1]));
        return lower(v,e[x].ch[0]);
    }
    int upper(int v,int x){
        if (!x) return inf;
        if (e[x].w>v) return min(e[x].w,upper(v,e[x].ch[0]));
        return upper(v,e[x].ch[1]);
    }
}S;
inline int read(){
    char c;int x=0,f=1;
    do{c=getchar();if(c=='-')f=-1;}while(c<48||c>57);
    do x=(x<<1)+(x<<3)+(c^48),c=getchar();while(c>=48&&c<=57);
    return f*x;
}
int main(){
    srand(time(0));
    scanf("%d",&n);
    S.push(-inf,rt);
    S.push(inf,rt);
    while (n--){
        opt=read();x=read();
        if (opt==1) S.push(x,rt);
        if (opt==2) S.pop(x,rt);
        if (opt==3) printf("%d\n",S.rank(x,rt)-1);
        if (opt==4) printf("%d\n",S.arank(x+1,rt));
        if (opt==5) printf("%d\n",S.lower(x,rt));
        if (opt==6) printf("%d\n",S.upper(x,rt));
    }
}

猜你喜欢

转载自blog.csdn.net/xumingyang0/article/details/80705821