【洛谷 P3369】 【模板】普通平衡树 --- Splay

传送门

题目描述

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入x数

  2. 删除x数(若有多个相同的数,应只删除一个)

  3. 查询x数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,应输出最小的排名)

  4. 查询排名为x的数

  5. 求x的前驱(前驱定义为小于x,且最大的数)

  6. 求x的后继(后继定义为大于x,且最小的数)

分析

就当作模板吧。
1. 注意updata的位置【位置不同,效率有区别:版本一的rotate每次同时更新f和p的size;版本二的rotate只更新f,p的size在splay结束时统一更新一次,省去了重复计算。见下】
2. 内存优化:对于被删除的节点,用 栈/队列 保存(对于子树,递归处理),当新建节点时,优先使用之前删除的节点(注意初始化)

代码

版本一 552ms

#include <cstdio>
#include <cstdlib>
#include <algorithm>

#define IL inline
#define maxn 100005

using namespace std;

// Read one (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 if end of input is reached first.
IL int read()
{
    int k = 1;
    int sum = 0;
    int c = getchar();   // int (not char) so that EOF is reliably distinguishable

    // Stop at EOF too: without this check the loop would spin forever.
    for(; c != EOF && ('0' > c || c > '9'); c = getchar())
        if(c == '-') k = -1;

    for(; '0' <= c && c <= '9'; c = getchar())
        sum = sum * 10 + c - '0';

    return sum * k;
}

// One splay-tree node. Equal keys share a node; cnt holds their multiplicity.
struct node
{
    node *father, *son[2];   // parent and children (son[0] smaller keys, son[1] larger keys)
    int val;                 // stored key
    int size;                // sum of cnt over this whole subtree
    int cnt;                 // multiplicity of val at this node
};
node *root;          // root of the splay tree (null when the tree is empty)
node *stk[maxn];     // free-node stack: stk[0..top) are available for newnode()
node tree[maxn];     // backing storage for every node
int top = maxn;      // number of entries currently on stk

// Take a node from the free stack and reset it as a leaf holding val under f.
// The node starts with cnt 0 and size 1; callers adjust the multiplicity.
IL node* newnode(int val, node *f)
{
    node *fresh = stk[-- top];
    fresh->father = f;
    fresh->son[0] = 0;
    fresh->son[1] = 0;
    fresh->val = val;
    fresh->cnt = 0;
    fresh->size = 1;
    return fresh;
}

// Push p back onto the free stack so a later newnode() can reuse it.
IL void freenode(node *p)
{
    stk[top] = p;
    ++ top;
}

// True iff p is f's right child (son[1]). A null f yields false.
IL bool son(node *f, node *p)
{
    return f != 0 && f->son[1] == p;
}

// Subtree size of p (multiplicities included), treating null as an empty tree.
IL int size(node *p)
{
    if(!p) return 0;
    return p->size;
}

// Recompute p's subtree size from its own count and its children's sizes.
IL void updata(node *p)
{
    p->size = size(p->son[0]) + p->cnt + size(p->son[1]);
}

// Hang p as child k of f; a null f makes p the root. p's parent link is set when p is non-null.
IL void connect(node *f, node *p, int k)
{
    if(f) f->son[k] = p;
    else root = p;

    if(p) p->father = f;
}

// Single rotation: lift p one level above its parent f, keeping in-order order.
IL void rotate(node *p)
{
    node *f = p->father;
    node *g = f->father;
    bool x = son(f, p), y = !x;   // x: side of f that p hangs on; y: the opposite side

    connect(f, p->son[y], x);     // p's inner child moves into p's old slot under f
    connect(g, p, son(g, f));     // p takes f's slot under g (becomes root if g is null)
    connect(p, f, y);             // f becomes p's child on the opposite side

    updata(f);                    // f is now below p, so it must be refreshed first
    updata(p);
}

// Rotate p upward until its parent is q; q == 0 makes p the new root.
IL void splay(node *p, node *q)
{
    while(p->father != q)
    {
        node *f = p->father;
        node *g = f->father;
        if(g != q)
        {
            // Two levels left: zig-zig (same side) lifts f first,
            // zig-zag (opposite sides) lifts p first.
            rotate(son(g, f) == son(f, p) ? f : p);
        }
        rotate(p);
    }
}

// Insert one copy of val, then splay its node to the root.
IL void insert(int val)
{
    // Descend to the node holding val, remembering the last node visited.
    node *f = 0, *p = root;
    while(p && p->val != val)
    {
        f = p;
        p = p->son[val > p->val];
    }

    if(p)
    {
        // val already present: bump its multiplicity.
        ++ (p->cnt);
        ++ (p->size);
    }else
    {
        // newnode() gives cnt 0 / size 1. Setting cnt to 1 makes a consistent
        // single-copy leaf, instead of revisiting it and leaving size at 2.
        p = newnode(val, f);
        p->cnt = 1;
        if(f) f->son[val > f->val] = p; else root = p;
    }
    // Rotations recompute the sizes of every ancestor on the access path.
    splay(p, 0);
}

// BST search for val. On success the node is splayed to the root and returned;
// otherwise returns null and leaves the tree unchanged.
IL node* find(int val)
{
    node *p = root;
    while(p && p->val != val)
        p = p->son[val > p->val];
    if(p) splay(p, 0);
    return p;
}

// In-order predecessor of p: splay p to the root, then take the rightmost node
// of its left subtree. Returns null when p has no left subtree.
IL node* find_pre(node *p)
{
    if(p != root) splay(p, 0);
    node *q = p->son[0];
    while(q && q->son[1]) q = q->son[1];
    return q;
}

// In-order successor of p: splay p to the root, then take the leftmost node
// of its right subtree. Returns null when p has no right subtree.
IL node* find_nxt(node *p)
{
    if(p != root) splay(p, 0);
    node *q = p->son[1];
    while(q && q->son[0]) q = q->son[0];
    return q;
}

// Remove one copy of val (no-op if val is absent). The node is recycled once
// its count drops to zero.
IL void erase(int val)
{
    node *p = find(val);            // find() splays p to the root
    if(!p) return ;
    -- (p->cnt);
    -- (p->size);
    if(p->cnt) return ;             // other copies remain: keep the node

    if(!p->son[0] || !p->son[1])
    {
        // At most one subtree: it becomes the whole tree.
        root = p->son[0] ? p->son[0] : p->son[1];
        if(root) root->father = 0;
    }else
    {
        // Bring p's predecessor t directly under p. As the maximum of p's left
        // subtree, t then has no right child, so p's right subtree goes there.
        node *t = find_pre(p);
        splay(t, p);
        root = t;
        t->father = 0;
        connect(t, p->son[1], 1);
        updata(t);                  // t's size must include the re-attached subtree
    }
    // Return p to the pool only after its fields have been read for the last time.
    freenode(p);
}

// 1-based rank of val: (number of stored values strictly less than val) + 1.
// val does not need to be present in the tree.
IL int rank(int val)
{
    int smaller = 0;
    node *last = 0;
    for(node *p = root; p; )
    {
        last = p;
        if(val == p->val)
        {
            smaller += size(p->son[0]);
            break;
        }
        if(val < p->val)
            p = p->son[0];
        else
        {
            smaller += size(p->son[0]) + p->cnt;
            p = p->son[1];
        }
    }
    // Splay the deepest visited node to keep the amortised cost bound.
    if(last) splay(last, 0);
    return smaller + 1;
}

// Value with 1-based rank t (multiplicities counted). Its node is splayed to the root.
// An out-of-range t walks off the tree and dereferences a null child.
IL int rerank(int t)
{
    node *p = root;
    for(;;)
    {
        int left = size(p->son[0]);
        if(t <= left)
            p = p->son[0];
        else if(t > left + p->cnt)
        {
            t -= left + p->cnt;
            p = p->son[1];
        }
        else
            break;
    }
    splay(p, 0);
    return p->val;
}

int main()
{
    int n = read();
    top = n;

    for(int i = 0; i < n; ++i) stk[i] = &tree[i];

    for(int i = 1, k, x; i <= n; ++ i)
    {
        k = read();
        if(k == 1)
        {
            insert(read());
        }else
        if(k == 2)
        {
            erase(read());
        }else
        if(k == 3)
        {
            printf("%d\n", rank(read()));
        }else
        if(k == 4)
        {
            printf("%d\n", rerank(read()));
        }else
        if(k == 5)
        {
            int x = read();
            node *p = find(x);
            bool f = 0;
            if(!p) { insert(x); p = find(x); f = 1; }
            printf("%d\n",  find_pre(p)->val);
            if(f) erase(x);
        }else
        {
            int x = read();
            node *p = find(x);
            bool f = 0;
            if(!p) { insert(x); p = find(x); f = 1; }
            printf("%d\n", find_nxt(p)->val);
            if(f) erase(x);
        }
    }
    return 0;
}

版本二:480ms

#include <cstdio>
#include <cstdlib>
#include <algorithm>

#define IL inline
#define maxn 100005

using namespace std;

// Read one (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 if end of input is reached first.
IL int read()
{
    int k = 1;
    int sum = 0;
    int c = getchar();   // int (not char) so that EOF is reliably distinguishable

    // Stop at EOF too: without this check the loop would spin forever.
    for(; c != EOF && ('0' > c || c > '9'); c = getchar())
        if(c == '-') k = -1;

    for(; '0' <= c && c <= '9'; c = getchar())
        sum = sum * 10 + c - '0';

    return sum * k;
}

// One splay-tree node. Equal keys share a node; cnt holds their multiplicity.
struct node
{
    node *father, *son[2];   // parent and children (son[0] smaller keys, son[1] larger keys)
    int val;                 // stored key
    int size;                // sum of cnt over this whole subtree
    int cnt;                 // multiplicity of val at this node
};
node *root;          // root of the splay tree (null when the tree is empty)
node *stk[maxn];     // free-node stack: stk[0..top) are available for newnode()
node tree[maxn];     // backing storage for every node
int top = maxn;      // number of entries currently on stk

// Take a node from the free stack and reset it as a leaf holding val under f.
// The node starts with cnt 0 and size 1; callers adjust the multiplicity.
IL node* newnode(int val, node *f)
{
    node *fresh = stk[-- top];
    fresh->father = f;
    fresh->son[0] = 0;
    fresh->son[1] = 0;
    fresh->val = val;
    fresh->cnt = 0;
    fresh->size = 1;
    return fresh;
}

// Push p back onto the free stack so a later newnode() can reuse it.
IL void freenode(node *p)
{
    stk[top] = p;
    ++ top;
}

// True iff p is f's right child (son[1]). A null f yields false.
IL bool son(node *f, node *p)
{
    return f != 0 && f->son[1] == p;
}

// Subtree size of p (multiplicities included), treating null as an empty tree.
IL int size(node *p)
{
    if(!p) return 0;
    return p->size;
}

// Recompute p's subtree size from its own count and its children's sizes.
IL void updata(node *p)
{
    p->size = size(p->son[0]) + p->cnt + size(p->son[1]);
}

// Hang p as child k of f; a null f makes p the root. p's parent link is set when p is non-null.
IL void connect(node *f, node *p, int k)
{
    if(f) f->son[k] = p;
    else root = p;

    if(p) p->father = f;
}

// Single rotation: lift p one level above its parent f, keeping in-order order.
// Only f's size is refreshed here; p's size is left for the caller to recompute
// (splay() does so once, after its last rotation).
IL void rotate(node *p)
{
    node *f = p->father;
    node *g = f->father;
    bool x = son(f, p), y = !x;   // x: side of f that p hangs on; y: the opposite side

    connect(f, p->son[y], x);     // p's inner child moves into p's old slot under f
    connect(g, p, son(g, f));     // p takes f's slot under g (becomes root if g is null)
    connect(p, f, y);             // f becomes p's child on the opposite side

    updata(f);                    // f is now below p; its children are final
}

// Rotate p upward until its parent is q; q == 0 makes p the new root.
IL void splay(node *p, node *q)
{
    while(p->father != q)
    {
        node *f = p->father;
        node *g = f->father;
        if(g != q)
        {
            // Two levels left: zig-zig (same side) lifts f first,
            // zig-zag (opposite sides) lifts p first.
            rotate(son(g, f) == son(f, p) ? f : p);
        }
        rotate(p);
    }
    // rotate() refreshes only the demoted parent, so p's size is recomputed once here.
    updata(p);
}

// Insert one copy of val, then splay its node to the root.
IL void insert(int val)
{
    // Descend to the node holding val, remembering the last node visited.
    node *f = 0, *p = root;
    while(p && p->val != val)
    {
        f = p;
        p = p->son[val > p->val];
    }

    if(p)
    {
        // val already present: bump its multiplicity.
        ++ (p->cnt);
        ++ (p->size);
    }else
    {
        // New leaf under f (or as the root of an empty tree) holding one copy.
        p = newnode(val, f);
        p->cnt = 1;
        if(f) f->son[val > f->val] = p; else root = p;
    }
    // splay() recomputes sizes along the path, finishing with p itself.
    splay(p, 0);
}

// BST search for val. On success the node is splayed to the root and returned;
// otherwise returns null and leaves the tree unchanged.
IL node* find(int val)
{
    node *p = root;
    while(p && p->val != val)
        p = p->son[val > p->val];
    if(p) splay(p, 0);
    return p;
}

// In-order predecessor of p: splay p to the root, then take the rightmost node
// of its left subtree. Returns null when p has no left subtree.
IL node* find_pre(node *p)
{
    if(p != root) splay(p, 0);
    node *q = p->son[0];
    while(q && q->son[1]) q = q->son[1];
    return q;
}

// In-order successor of p: splay p to the root, then take the leftmost node
// of its right subtree. Returns null when p has no right subtree.
IL node* find_nxt(node *p)
{
    if(p != root) splay(p, 0);
    node *q = p->son[1];
    while(q && q->son[0]) q = q->son[0];
    return q;
}

// Remove one copy of val (no-op if val is absent). The node is recycled once
// its count drops to zero. (Name kept as `earse` because main() calls it.)
IL void earse(int val)
{
    node *p = find(val);            // find() splays p to the root
    if(!p) return ;
    -- (p->cnt);
    -- (p->size);
    if(p->cnt) return ;             // other copies remain: keep the node

    if(!p->son[0] || !p->son[1])
    {
        // At most one subtree: it becomes the whole tree.
        root = p->son[0] ? p->son[0] : p->son[1];
        if(root) root->father = 0;
    }else
    {
        // Bring p's predecessor t directly under p. As the maximum of p's left
        // subtree, t then has no right child, so p's right subtree goes there.
        node *t = find_pre(p);
        splay(t, p);
        root = t;
        t->father = 0;
        connect(t, p->son[1], 1);
        updata(t);                  // t's size must include the re-attached subtree
    }
    // Return p to the pool only after its fields have been read for the last time.
    freenode(p);
}

// 1-based rank of val: (number of stored values strictly less than val) + 1.
// val does not need to be present in the tree.
IL int rank(int val)
{
    int smaller = 0;
    node *last = 0;
    for(node *p = root; p; )
    {
        last = p;
        if(val == p->val)
        {
            smaller += size(p->son[0]);
            break;
        }
        if(val < p->val)
            p = p->son[0];
        else
        {
            smaller += size(p->son[0]) + p->cnt;
            p = p->son[1];
        }
    }
    // Splay the deepest visited node to keep the amortised cost bound.
    if(last) splay(last, 0);
    return smaller + 1;
}

// Value with 1-based rank t (multiplicities counted). Its node is splayed to the root.
// An out-of-range t walks off the tree and dereferences a null child.
IL int rerank(int t)
{
    node *p = root;
    for(;;)
    {
        int left = size(p->son[0]);
        if(t <= left)
            p = p->son[0];
        else if(t > left + p->cnt)
        {
            t -= left + p->cnt;
            p = p->son[1];
        }
        else
            break;
    }
    splay(p, 0);
    return p->val;
}

int main()
{
    int n = read();
    top = n;

    for(int i = 0; i < n; ++i) stk[i] = &tree[i];

    for(int i = 1, k, x; i <= n; ++ i)
    {
        k = read();
        if(k == 1)
        {
            insert(read());
        }else
        if(k == 2)
        {
            earse(read());
        }else
        if(k == 3)
        {
            printf("%d\n", rank(read()));
        }else
        if(k == 4)
        {
            printf("%d\n", rerank(read()));
        }else
        if(k == 5)
        {
            int x = read();
            node *p = find(x);
            bool f = 0;
            if(!p) { insert(x); p = find(x); f = 1; }
            printf("%d\n",  find_pre(p)->val);
            if(f) earse(x);
        }else
        {
            int x = read();
            node *p = find(x);
            bool f = 0;
            if(!p) { insert(x); p = find(x); f = 1; }
            printf("%d\n", find_nxt(p)->val);
            if(f) earse(x);
        }

    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_27121257/article/details/79395627