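// 伸展树模板 (Splay tree template)
//
// 版权声明：本文为博主原创文章，未经博主允许不得转载。https://blog.csdn.net/weixin_37517391/article/details/82834483
// (Copyright notice: this is the blogger's original article; do not reproduce without permission.)
//
// 伸展树模板 (Splay tree template)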


#include <iostream>
#include <cstdio>

// Debug helper: prints "<expression text> : <value>" followed by a newline.
// The argument is parenthesized so operators of lower precedence than `<<`
// (e.g. `?:`, `&`, `==`) inside `x` are evaluated before being streamed.
#define pr(x) (std::cout << #x << " : " << (x) << std::endl)

// Splay tree over int keys where every node also stores its subtree size.
//   - insert(val): BST insertion by value, then the new node is splayed to the
//     root. A value equal to an existing key is placed in that key's left
//     subtree, so the in-order sequence stays non-decreasing.
//   - select(t, k) / erase(k): lookup / removal by 1-based in-order rank.
//   - lower_bound(val): first node (in order) whose value is >= val.
//   - reverse(t) + pushdown(t): lazy reversal of a subtree's in-order sequence.
// Value-ordered operations (insert, lower_bound) only make sense while no
// reversal is in effect; rank-ordered operations (select, erase, splay) honour
// pending reversal tags.
// Nodes are allocated with `new`; erase() frees the node it removes. There is
// no destructor, so nodes still in the tree are not freed when it goes away.
class SplayTree{
public:
    struct Node{
        int val;        // stored key
        int size;       // number of nodes in the subtree rooted here
        int tag;        // 1 = this subtree's in-order sequence must still be reversed
        Node *father;   // parent, or nullptr for the root
        Node *son[2];   // son[0] = left child, son[1] = right child

        // Creates a leaf holding `val` whose parent pointer is `father`.
        // The caller must still link it into father->son[...].
        Node(int val,Node *father) {
            this -> son[0] = this -> son[1] = nullptr;
            this -> tag = 0;
            this -> val = val;
            this -> size = 1;   // a fresh leaf is a subtree of one node (was left uninitialized)
            this -> father = father;
        }

    }*root;

    SplayTree(){
        root = nullptr;
    }

    // Returns true iff `s` is the right child of `f`; false when `f` is null.
    inline bool son(Node *f,Node *s) {
        return f && f -> son[1] == s;
    }

    // Makes `s` the child of `f` on side `k` (0 = left, 1 = right).
    // A null `f` makes `s` the root; a null `s` only clears the slot.
    inline void connect(Node *f,Node *s,bool k) {
        if(f == nullptr)
            root = s;
        else
            f -> son[k] = s;

        if(s != nullptr)
            s -> father = f;
    }

    // Recomputes t->size from the sizes of t's children.
    inline void maintain(Node *t) {
        t -> size = 1;
        if(t -> son[0]) t -> size += t -> son[0] -> size;
        if(t -> son[1]) t -> size += t -> son[1] -> size;
    }

    // Single rotation that lifts `t` above its parent `f`, keeping the
    // in-order sequence unchanged. `t` must have a parent.
    inline void rotate(Node *t) {
        Node *f = t -> father;
        Node *g = f -> father;
        pushdown(f);
        pushdown(t);
        bool a = son(f,t);
        connect(f,t->son[!a],a);
        connect(g,t,son(g,f));
        connect(t,f,!a);
        maintain(f);
        maintain(t);
    }

    // Rotates `t` upward until its parent is `p` (p == nullptr: t becomes the root).
    // `p` must be nullptr or a proper ancestor of `t`.
    inline void splay(Node *t,Node *p) {
        while(p != t -> father) {
            Node *f = t -> father;
            Node *g = f -> father;
            if(g == p) {
                rotate(t);
            }
            else {
                // Apply pending reversals on g, f, t first so the zig-zig /
                // zig-zag choice below sees the real child sides. (Tags still
                // pending above g flip both sides together, so they do not
                // change the comparison.)
                pushdown(g);
                pushdown(f);
                pushdown(t);
                if(son(g,f) ^ son(f,t))
                    rotate(t),rotate(t);
                else
                    rotate(f),rotate(t);
            }
        }
    }

    // Inserts `val` (duplicates allowed) and splays the new node to the root.
    inline void insert(int val) {
        if(root == nullptr) {
            root = new Node(val,nullptr);
            return ;
        }
        for(Node* t = root;t;t = t -> son[val > t -> val]) {
            pushdown(t);
            bool a = val > t -> val;
            if(t -> son[a] == nullptr) {
                t -> son[a] = new Node(val,t);
                // No maintain() needed here: splaying rotates every ancestor
                // of the new node and rotate() recomputes their sizes.
                splay(t -> son[a],nullptr);
                break;
            }
        }
    }

    // Removes and frees the node with 1-based in-order rank `k`.
    // Does nothing if the tree is empty or k is outside [1, size].
    // Any outstanding pointer to the removed node becomes dangling.
    inline void erase(int k) {
        if(!root || k < 1 || k > root -> size)
            return ;
        if(root -> size == 1) {
            delete root;
            root = nullptr;
            return ;
        }
        Node *victim = nullptr;
        if(k == 1) {
            // The minimum, once at the root, has no left child: its right subtree replaces it.
            splay(select(root, k),nullptr);
            pushdown(root);
            victim = root;
            root = root -> son[1];
            root -> father = nullptr;
        }
        else if(k == root -> size) {
            // The maximum, once at the root, has no right child: its left subtree replaces it.
            splay(select(root, k),nullptr);
            pushdown(root);
            victim = root;
            root = root -> son[0];
            root -> father = nullptr;
        }
        else {
            // With rank k-1 at the root and rank k+1 as its right child, the
            // left subtree of k+1 holds exactly the single node of rank k.
            splay(select(root, k-1),nullptr);
            splay(select(root, k+1),root);
            pushdown(root);
            pushdown(root -> son[1]);
            victim = root -> son[1] -> son[0];
            root -> son[1] -> son[0] = nullptr;
            maintain(root -> son[1]);
            maintain(root);
        }
        delete victim;
    }

    // Returns the node with 1-based in-order rank `k` in the subtree rooted at
    // `t`, or nullptr if `t` is null or k is outside [1, t->size].
    // Pushes reversal tags down along the search path; does not splay.
    inline Node* select(Node *t,int k) {
        while(t) {
            pushdown(t);
            int left = t -> son[0] ? t -> son[0] -> size : 0;
            if(k == left + 1)
                return t;
            if(k <= left) {
                t = t -> son[0];
            }
            else {
                k -= left + 1;
                t = t -> son[1];
            }
        }
        return nullptr;
    }


    // Returns the first node (in order) whose value is >= `val`, or nullptr if
    // none exists. The deepest node visited is splayed to the root so repeated
    // queries keep the amortized O(log n) bound.
    // Not meaningful while reversal tags are in effect, so the descent does
    // not push them down.
    inline Node* lower_bound(int val) {
        Node* ans = nullptr;
        Node* last = nullptr;
        for(Node *t = root;t;t = t -> son[val > t -> val]) {
            last = t;
            if(t -> val >= val) ans = t;
        }
        if(last)
            splay(last, nullptr);
        return ans;
    }

    // Marks the subtree rooted at `t` (if any) for in-order reversal; the tag
    // is applied lazily by pushdown(). Toggling twice cancels out.
    inline void reverse(Node *t) {
        if(t)
            t -> tag ^= 1;
    }

    // Applies t's pending reversal: swaps its children and forwards the tag to them.
    inline void pushdown(Node *t) {
        if(t && t -> tag) {
            std::swap(t -> son[0],t -> son[1]);
            reverse(t -> son[0]);
            reverse(t -> son[1]);
            t -> tag ^= 1;
        }
    }
};
int n,m;
int main()
{
    SplayTree tree;
    std::ios::sync_with_stdio(false);
    std::cin >> m;
    while(m --) {
        int tp;
        std::cin >> tp;
        if(tp == 1) {
            //插入一个数x
            int x;
            std::cin >> x;
            tree.insert(x);
        }
        else if(tp == 2) {
            //删除数值为x的数
            int x;
            std::cin >> x;
            auto res = tree.lower_bound(x);
            if(res && res -> val == x) {
                tree.splay(res,NULL);
                int k = (tree.root -> son[0] ? tree.root -> son[0] -> size:0) + 1;
                tree.erase(k);
            }
        }
        else if(tp == 3) {
            //输出数值为x的数的排名
            int x;
            std::cin >> x ;
            auto res = tree.lower_bound(x);
            if(res == NULL) {
                std::cout << res->size + 1 << std::endl;
                continue;
            }
            tree.splay(res,NULL);
            int k = 1;
            if(res -> son[0]) k += res -> son[0] -> size;
            std::cout << k << std::endl;
        }
        else if(tp == 4) {
            //找到第k个数字
            int k;
            std::cin >> k;
            std::cout << tree.select(tree.root,k)->val << std::endl;
        }
        else if(tp == 5) {
            //输出比x小的最大的数
            int x;
            std::cin >> x;
            int ans = -1000000000;
            for(auto t = tree.root;t;t = t -> son[x > t -> val]) {
                if(t -> val < x && t -> val > ans) ans = t -> val;
            }
            std::cout << ans << std::endl;
        }
        else {
            //输出比x大的最小的数
            int x;
            std::cin >> x;
            int ans = 1000000000;
            for(auto t = tree.root;t;t = t -> son[x >= t -> val]) {
                if(t -> val > x && t -> val < ans) ans = t -> val;
            }
            std::cout << ans << std::endl;
        }
    }
    return 0;
}

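// 猜你喜欢 ("You may also like": related-posts widget text captured from the blog page)
//
// 转载自 (reposted from) blog.csdn.net/weixin_37517391/article/details/82834483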