[P3369]普通平衡树(Splay版)

模板,不解释

#include<bits/stdc++.h>
using namespace std;
const int mxn=1e5+5;
int fa[mxn],ch[mxn][2],sz[mxn],cnt[mxn],val[mxn],rt,tot;
namespace Splay {
    void push_up(int x) {
        sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
    };
    void rotate(int x) {
        int y=fa[x],z=fa[y],tp=ch[y][1]==x;
        ch[z][ch[z][1]==y]=x,fa[x]=z; //这里容易写错
        ch[y][tp]=ch[x][tp^1],fa[ch[x][tp^1]]=y;
        ch[x][tp^1]=y,fa[y]=x;
        push_up(y),push_up(x);
    };
    void splay(int x,int gl) {
        while(fa[x]!=gl) {
            int y=fa[x],z=fa[y];
            if(z!=gl)
                (ch[y][1]==x)^(ch[z][1]==y)?rotate(x):rotate(y);
            rotate(x);
        }
        if(gl==0) rt=x;
    };
    void find(int x) {
        int u=rt;
        while(ch[u][x>val[u]]/*这里不一定find的到该值,所以一定要加这句话*/&&x!=val[u]) u=ch[u][x>val[u]]; 
        splay(u,0);
    };
    int kth(int k) {
        int u=rt;
        while(1) {
            if(k<=sz[ch[u][0]]) u=ch[u][0];
            else if(k>sz[ch[u][0]]+cnt[u]) k-=sz[ch[u][0]]+cnt[u],u=ch[u][1];
            else return u;
        }
    };
    void ins(int x) {
        int u=rt,f=0;
        while(val[u]!=x&&u) f=u,u=ch[u][x>val[u]];
        if(u==0) {
            u=++tot;
            if(f) ch[f][x>val[f]]=u; 
            val[u]=x; fa[u]=f;
            cnt[u]=sz[u]=1;
        }
        else ++cnt[u];
        splay(u,0);
    };
    int pre(int x) {
        find(x);
        if(val[rt]<x) return rt;
        int u=ch[rt][0];
        while(ch[u][1]) u=ch[u][1];
        return u;
    };
    int nxt(int x) {
        find(x);
        if(val[rt]>x) return rt;
        int u=ch[rt][1];
        while(ch[u][0]) u=ch[u][0];
        return u;
    };
    void erase(int x) {
        find(x);
        if(cnt[rt]>1) --cnt[rt];
        else {
            int l=pre(x),r=nxt(x); //这里容易写错
            splay(l,0); splay(r,l);
            ch[r][0]=0;
        }
    };
}

int main()
{
    using namespace Splay;
    int t,opt,x;
    scanf("%d",&t);
    ins(-1000000000),ins(1000000000);//切记插入端点,否则前驱后继不好求
    while(t--) {
        scanf("%d %d",&opt,&x);
        if(opt==1) ins(x);
        else if(opt==2) erase(x);
        else if(opt==3) find(x),printf("%d\n",sz[ch[rt][0]]);
        else if(opt==4) printf("%d\n",val[kth(x+1)]);
        else if(opt==5) printf("%d\n",val[pre(x)]);
        else printf("%d\n",val[nxt(x)]);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/list1/p/10362914.html