// Tyvj 1728 普通平衡树 ("ordinary balanced tree", same task as Luogu P3369): splay-tree solution.

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
// Capacity of the node pool; index 0 is reserved as the null/sentinel node.
// NOTE(review): crepoint() does no bounds check, so the highest node index n
// must stay below maxn.
const int maxn=100010;
// Returned by upper() when there is no successor, and negated (-INF) by
// lower() and atrank() when there is no answer.
const int INF=233333333;

// One splay-tree node.
struct node {
    int v,father;   // stored key; parent index (0 when this node is the root)
    int ch[2];      // children: ch[0] holds smaller keys, ch[1] larger keys
    int sum;        // number of stored values in this subtree, counting duplicates
    int cnt;        // multiplicity of v at this node
};
// Node pool. e[0] is a sentinel: its zero fields represent an empty subtree,
// and e[0].ch[1] holds the index of the tree's root.
node e[maxn];
int n, points;  // n: highest node index handed out; points: values stored (with duplicates)

// The tree's root, kept as the right child of the sentinel. A reference names
// the same lvalue the former `#define root e[0].ch[1]` did, without a
// file-wide textual substitution.
int &root = e[0].ch[1];

void update(int x) {
    e[x].sum = e[e[x].ch[0]].sum + e[e[x].ch[1]].sum + e[x].cnt;
}

// Which side of its parent x hangs on: 0 if it is the left child, 1 otherwise.
int identify(int x) {
    const int parent = e[x].father;
    return static_cast<int>(e[parent].ch[0] != x);
}

// Make x the `son`-side child of f (son: 0 = left, 1 = right) and point x's
// parent link at f. f may be 0: e[0].ch[1] is the root slot, so
// connect(x, 0, 1) installs x as the root. x may be 0 (an empty subtree): the
// child slot is still cleared, but the sentinel's father field is left alone
// so node 0 keeps its all-zero "null node" state.
void connect(int x,int f,int son) {
    if (x) e[x].father=f;
    e[f].ch[son]=x;
}

// Single rotation: lift x above its parent while preserving in-order key
// order, then recompute the subtree sizes of the two nodes that moved.
void rotate(int x) {
    const int parent = e[x].father;
    const int grand = e[parent].father;
    const int parentSide = identify(parent);   // where parent hangs under grand
    const int side = identify(x);              // where x hangs under parent
    // The subtree lying between x and parent in key order changes owner.
    const int middle = e[x].ch[side ^ 1];

    connect(middle, parent, side);        // middle takes x's old slot under parent
    connect(parent, x, side ^ 1);         // parent drops below x
    connect(x, grand, parentSide);        // x takes parent's old slot under grand

    // parent is now x's child, so its size must be refreshed first.
    update(parent);
    update(x);
}

// Rotate `at` upward until it occupies the position `to` holds now, i.e. until
// at's parent equals to's current parent. Called with to == root this brings
// `at` to the root. When a grandparent step is possible, a zig-zig (parent and
// at on the same side) rotates the parent first; a zig-zag rotates at twice.
void splay(int at, int to) {
    const int stop = e[to].father;
    for (int up = e[at].father; up != stop; up = e[at].father) {
        if (e[up].father != stop) {
            if (identify(up) == identify(at)) rotate(up);   // zig-zig
            else rotate(at);                                // zig-zag
        }
        rotate(at);
    }
}

void destroy(int x) {
    e[x].v=e[x].ch[0]=e[x].ch[1]=e[x].sum=e[x].father=e[x].cnt=0;
    if(x==n) n--;
}

// Look up the node holding v.
// Hit: the node is splayed to the root and its index is returned.
// Miss: 0 is returned, and the last node on the search path is splayed
// instead, so failed searches also pay for themselves as splay accesses.
// The null check comes before the value comparison so the sentinel e[0]
// (whose v is 0) is never mistaken for a match when the tree is empty.
int find(int v) {
    int now = root;
    int last = 0;
    while (now) {
        last = now;
        if (e[now].v == v) {
            splay(now, root);
            return now;
        }
        now = e[now].ch[v < e[now].v ? 0 : 1];
    }
    if (last) splay(last, root);
    return 0;
}

int crepoint(int v,int father) {
    n++;
    e[n].v=v;
    e[n].father=father;
    e[n].sum=e[n].cnt=1;
    return n;
}

// Plain BST insertion of one copy of v, without rebalancing. Returns the index
// of the node that now holds v: a new leaf, or an existing node whose cnt grew.
// Every node on the search path has its subtree size bumped on the way down.
// An empty tree gets the new node as its root, and that node's index is
// returned. Previously this case fell through to `return 0`, which made
// push() splay the sentinel.
int build(int v) {
    points++;
    if (root == 0) {   // empty tree
        root = crepoint(v, 0);
        return root;
    }
    int now = root;
    while (true) {
        e[now].sum++;
        if (v == e[now].v) {
            e[now].cnt++;
            return now;
        }
        const int next = v < e[now].v ? 0 : 1;
        if (!e[now].ch[next]) {
            const int fresh = crepoint(v, now);
            e[now].ch[next] = fresh;
            return fresh;
        }
        now = e[now].ch[next];
    }
}

// Insert one copy of v, then splay the node holding it to the root.
void push(int v) {
    const int holder = build(v);
    splay(holder, root);
}

void pop(int v) { 
    int deal=find(v);
    if(!deal) return;
    points--;
    if(e[deal].cnt>1) {
        e[deal].cnt--;
        e[deal].sum--;
        return;
    }
    if(!e[deal].ch[0]) {
        root=e[deal].ch[1];
        e[root].father=0;
    }
    else {
        int lef=e[deal].ch[0];
        while(e[lef].ch[1]) lef=e[lef].ch[1];
        splay(lef,e[deal].ch[0]);
        int rig=e[deal].ch[1];
        connect(rig,lef,1); connect(lef,0,1);
        update(lef); 
    }
    destroy(deal);
}

// Rank query: 1 + the number of stored values (counting duplicates) strictly
// smaller than v. v does not need to be present. The node holding v, or the
// last node visited when v is absent, is splayed to the root.
// Previously an absent v returned 0. An absent v == 0 instead "matched" the
// sentinel e[0], because the value was compared before the null check.
int Rank(int v) {
    int smaller = 0;
    int last = 0;
    int now = root;
    while (now) {
        last = now;
        const int left = e[now].ch[0];
        if (v == e[now].v) {
            smaller += e[left].sum;
            break;
        }
        if (v < e[now].v) {
            now = left;
        } else {
            // Everything in the left subtree plus this node's copies is < v.
            smaller += e[left].sum + e[now].cnt;
            now = e[now].ch[1];
        }
    }
    if (last) splay(last, root);
    return smaller + 1;
}

// Return the x-th smallest stored value (1-based, duplicates counted) and
// splay its node to the root. Returns -INF when x lies outside [1, points].
// The x < 1 guard matters: without it the descent walks into the sentinel
// (whose sum is 0) and never terminates.
int atrank(int x) {
    if (x < 1 || x > points) return -INF;
    int now = root;
    while (true) {
        const int leftSize = e[e[now].ch[0]].sum;
        if (x <= leftSize) {
            now = e[now].ch[0];
        } else if (x <= leftSize + e[now].cnt) {
            break;
        } else {
            x -= leftSize + e[now].cnt;
            now = e[now].ch[1];
        }
    }
    splay(now, root);
    return e[now].v;
}

// Successor query: the smallest stored value strictly greater than v, or INF
// if none exists. The node holding the answer is splayed to the root; when
// there is no answer, the root itself is splayed (a no-op).
int upper(int v) {
    int best = INF;
    int bestNode = root;
    for (int cur = root; cur != 0; cur = e[cur].ch[v < e[cur].v ? 0 : 1]) {
        const int val = e[cur].v;
        if (val > v && val < best) {
            best = val;
            bestNode = cur;
        }
    }
    splay(bestNode, root);
    return best;
}

// Predecessor query: the largest stored value strictly less than v, or -INF
// if none exists. The node holding the answer is splayed to the root; when
// there is no answer, the root itself is splayed (a no-op).
int lower(int v) {
    int best = -INF;
    int bestNode = root;
    for (int cur = root; cur != 0; cur = e[cur].ch[e[cur].v < v ? 1 : 0]) {
        const int val = e[cur].v;
        if (val < v && val > best) {
            best = val;
            bestNode = cur;
        }
    }
    splay(bestNode, root);
    return best;
}

// Reads t, then t queries "op x":
//   1 insert x, 2 delete one x, 3 rank of x, 4 x-th smallest,
//   5 predecessor of x, 6 successor of x (any other op is also treated as 6).
// Streams are unsynced/untied and lines end with '\n' rather than endl, so
// output is not flushed after every answer. The program uses only cin/cout.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    int t;
    if (!(cin >> t)) return 0;
    while (t-- > 0) {
        int op, x;
        if (!(cin >> op >> x)) break;
        switch (op) {
            case 1: push(x); break;
            case 2: pop(x); break;
            case 3: cout << Rank(x) << '\n'; break;
            case 4: cout << atrank(x) << '\n'; break;
            case 5: cout << lower(x) << '\n'; break;
            default: cout << upper(x) << '\n'; break;   // op 6
        }
    }
    return 0;
}

// Solutions (题解): https://www.luogu.org/problemnew/solution/P3369
// Reposted from (转载自): www.cnblogs.com/ertuan/p/11251024.html