快速入门Splay

\(splay\) :伸展树(\(Splay Tree\)),也叫分裂树,是一种二叉排序树,它能在\(O(log n)\)内完成插入、查找和删除操作。它由\(Daniel Sleator\)\(Robert Tarjan\)创造,后勃刚对其进行了改进。它的优势在于不需要记录用于平衡树的冗余信息。在伸展树上的一般操作都基于伸展操作。

先让我们看一下一棵二叉搜索树(\(Binary\) \(Search\) \(Tree\))是什么样子的。

如图所示,对任意一棵\(BST\),它有以下性质:

  • 是一棵空树,或者是具有下列性质的二叉树
    • 若它的左子树不空,则左子树上所有结点的值均小于它的根结点的值;
    • 若它的右子树不空,则右子树上所有结点的值均大于它的根结点的值;

根据定义,我们会发现:

  • 这棵树的中序遍历,与其压成数组的升序排序等效

在这样一棵树中,我们可以很容易地维护以下信息:

  • 查询\(x\)数的排名
  • 查询排名为\(x\)的数
  • \(x\)的前驱(前驱定义为小于\(x\),且最大的数)
  • \(x\)的后继(后继定义为大于\(x\),且最小的数)

同样的我们会发现,对于一个固定的数列,它可以形成很多种不同类型的\(BST\)。如果这棵树恰好不太优美,每次维护的复杂度可能会被卡到\(O(N)\)(一条链)。

所以,平衡树这种伟大的数据结构就诞生啦!

顾名思义,平衡树就是一棵可以保持全树平衡的二叉搜索树,以此避免复杂度退化为\(O(N)\)。比较经典的一种平衡树是\(Treap\),它基于的是对一个有序数列,随机出的\(BST\)期望复杂度是\(O(logN)\),通过利用堆的性质来维护其随机性,这个东西我的上一篇博客已经介绍过,不再展开介绍。今天我们要介绍的是另一种经典的平衡树——\(Splay\)

既然是平衡树,\(Splay\)是如何实现其树体平衡的呢?

\(Splay\)的每一个维护操作中,维护结束后当前被维护的点都会被旋转成为树的根节点,这个过程叫做树的伸展。\((Splay)\)。伸展是\(Splay\)的核心操作。与\(Treap\)利用随机出来的优先级进行堆的维护不同,\(Splay\)的大多数操作都要基于伸展操作,这也决定了\(Splay\)相比前者具有更广泛的适用性。

\(Splay\)是怎么保证其复杂度不退化成\(O(N)\)的呢?来看个例子。

在这一棵已经退化成链的\(BST\)中,我们对最底下那个节点进行了一次维护。在这之后,这个节点就开始了向根节点的漫漫伸展之路~

所以在伸展过程结束后,这棵树就再次自发地进化回了一棵正常的树。如果深度更深会更加明显,在一次\(Splay\)以后,它会从\(N\)级别的深度进化为\(logN\)级别。

接着让我们贪心地想一想,假如现在这棵树非常的不优秀。我想要把它卡掉,就应该总是访问它最不优秀的节点。如果最开始它还有很多超级长的链,那么经过几次贪心的访问之后,它的所有链中的最大深度就已经回到\(logN\)了。不管常数怎么样,均摊一下复杂度是没有问题了。

既然这些操作\(Treap\)也能做,为什么不用又快又好写的\(Treap\)呢?因为\(Splay\)在区间操作和\(LCT\)中有其不可替代的作用。具体是什么作用,我也没有学到,等到学了在拿出来讲吧QwQ

讲过了原理,我们可以来看一下代码实现了Qw

inline void push_up (int u) {
    t[u].sz = t[u].cnt;
    t[u].sz += t[t[u].ch[0]].sz;
    t[u].sz += t[t[u].ch[1]].sz;
}

inline void rotate (int x) {
    int y = t[x].fa;
    int z = t[y].fa;
    int d1 = t[y].ch[1] == x;
    int d2 = t[z].ch[1] == y;
    connect (z, x,             d2);
    connect (y, t[x].ch[!d1],  d1);
    connect (x, y           , !d1);         
    push_up (y);
    push_up (x);
} 

这里\(connect\)是一个连边的函数,旋转的原理和\(Treap\)一样,都是要保证其\(BST\)的性质,可以手画一下示意图就明白啦~

inline void splay (int x, int goal) {
    if (x == 0) return;
    while (t[x].fa != goal) {
        int y = t[x].fa;
        int z = t[y].fa;
        int d1 = t[y].ch[1] == x;
        int d2 = t[z].ch[1] == y;
        if (z != goal) {
            if (d1 == d2) {
                rotate (y);
            } else {
                rotate (x);
            }
        }
        rotate (x);
    }
    if (goal == 0) {
        root = x;
    } 
}

核心操作——伸展,可以思考一下:为什么是把\(x\)旋转为\(goal\)的子节点?

剩下的操作,作者很懒,就只贴上代码啦~

inline void find (int key) {
    int u = root;
    if (u == 0) return;
    while (t[u].key != key && t[u].ch[key > t[u].key]) {
        u = t[u].ch[key > t[u].key];
    }
    //找到key对应的节点,并把它旋转到根。
    splay (u, 0);
}

inline void Insert (int key) {
    int u = root, fa = 0;
    while (u != 0 && t[u].key != key) {
        fa = u;//记得记录父亲
        if (key > t[u].key) {
            u = t[u].ch[1];
        } else {
            u = t[u].ch[0];
        }
    }
    if (u != 0) {
        //已有(能查到)
        ++t[u].sz;
        ++t[u].cnt;
    } else {
        //新增
        u = ++max_size;
        t[u].sz = 1;
        t[u].cnt = 1;
        t[u].key = key;
        connect (fa, u, key > t[fa].key); 
    }
    splay (u, 0);
} 

inline int Next (int key, int dir) {
    //dir = 0 -> 前驱
    //dir = 1 -> 后继
    find (key);
    int u = root;
    if (dir == 0 && t[u].key < key) return u;
    if (dir == 1 && t[u].key > key) return u; 
    //如果key值并没有存在于树中:
    u = t[u].ch[dir];
    while (t[u].ch[!dir]) {
        u = t[u].ch[!dir];
    }
    //e.g 如果要找前驱,就先往左一步(保证一定比当前值更小),再一直向右(最大的那个)。
    return u;
}

inline void Delete (int key) {
    int _pre = Next (key, 0);
    int _nxt = Next (key, 1);
    splay (_pre, 0000);
    splay (_nxt, _pre);
    //当前键值key的前驱是_pre, 后继是_nxt
    //_pre被旋转到根节点,_nxt成为_pre的子节点(显然是右)
    //那么当前点一定在_nxt的左边,而且底下没有任何一个点。
    int u = t[_nxt].ch[0];
    if (t[u].cnt > 1) {
        --t[u].cnt;
        splay (u, 0);
    } else {
        t[_nxt].ch[0] = 0;
    }
}

inline int kth (int k) {
    int u = root;
    if (u == 0) return 0;
    while (u != 0) {
        int ls = t[u].ch[0];
        int rs = t[u].ch[1];
        if (k > t[ls].sz + t[u].cnt) {
            k -= t[ls].sz + t[u].cnt;
            u = rs;//格外注意不要写反顺序
        } else if (k <= t[ls].sz) {
            u = ls;
        } else {
            return t[u].key;
        }
    }
    return false;
}

inline int get_rnk (int key) {
    find (key);
    return t[t[root].ch[0]].sz;
}

还有一点需要注意的,\(splay\)在使用前要先\(insert\)一个极大值和一个极小值。否则在\(Next\)函数的查找中,比如只有一个点的话,会出现找不到前驱和后继的情况,也就会导致出莫名其妙的锅。当然,加上极大极小值之后要格外注意对答案的处理。下面给出完整代码,题目P3369 【模板】普通平衡树

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 100010
#define INF 0x7fffffff
using namespace std;

struct Splay_Tree {
    int root, max_size;
    
    struct Splay_Node {
        int sz, fa, cnt, key, ch[2];
    }t[N];
    
    Splay_Tree () {
        root = max_size = 0;
        memset (t, 0, sizeof (t));
    }
    
    inline void connect (int u, int v, int dir) {
        t[u].ch[dir] = v;
        t[v].fa = u;
    }
    
    inline void push_up (int u) {
        t[u].sz = t[u].cnt;
        t[u].sz += t[t[u].ch[0]].sz;
        t[u].sz += t[t[u].ch[1]].sz;
    }
    
    inline void rotate (int x) {
        int y = t[x].fa;
        int z = t[y].fa;
        int d1 = t[y].ch[1] == x;
        int d2 = t[z].ch[1] == y;
        connect (z, x,             d2);
        connect (y, t[x].ch[!d1],  d1);
        connect (x, y           , !d1);         
        push_up (y);
        push_up (x);
    } 
    
    inline void splay (int x, int goal) {
        if (x == 0) return;
        while (t[x].fa != goal) {
            int y = t[x].fa;
            int z = t[y].fa;
            int d1 = t[y].ch[1] == x;
            int d2 = t[z].ch[1] == y;
            if (z != goal) {
                if (d1 == d2) {
                    rotate (y);
                } else {
                    rotate (x);
                }
            }
            rotate (x);
        }
        if (goal == 0) {
            root = x;
        } 
    }
    
    inline void find (int key) {
        int u = root;
        if (u == 0) return;
        while (t[u].key != key && t[u].ch[key > t[u].key]) {
            u = t[u].ch[key > t[u].key];
        }
        splay (u, 0);
    }
    
    inline void Insert (int key) {
        int u = root, fa = 0;
        while (u != 0 && t[u].key != key) {
            fa = u;
            if (key > t[u].key) {
                u = t[u].ch[1];
            } else {
                u = t[u].ch[0];
            }
        }
        if (u != 0) {
            ++t[u].sz;
            ++t[u].cnt;
        } else {
            u = ++max_size;
            t[u].sz = 1;
            t[u].cnt = 1;
            t[u].key = key;
            connect (fa, u, key > t[fa].key); 
        }
        splay (u, 0);
    } 
    
    inline int Next (int key, int dir) {
        find (key);
        int u = root;
        if (dir == 0 && t[u].key < key) return u;
        if (dir == 1 && t[u].key > key) return u; 
        u = t[u].ch[dir];
        while (t[u].ch[!dir]) {
            u = t[u].ch[!dir];
        }
        return u;
    }
    
    inline void Delete (int key) {
        int _pre = Next (key, 0);
        int _nxt = Next (key, 1);
        splay (_pre, 0000);
        splay (_nxt, _pre);
        int u = t[_nxt].ch[0];
        if (t[u].cnt > 1) {
            --t[u].cnt;
            splay (u, 0);
        } else {
            t[_nxt].ch[0] = 0;
        }
    }
    
    inline int kth (int k) {
        int u = root;
        if (u == 0) return 0;
        while (u != 0) {
            int ls = t[u].ch[0];
            int rs = t[u].ch[1];
            if (k > t[ls].sz + t[u].cnt) {
                k -= t[ls].sz + t[u].cnt;
                u = rs;//格外注意 
            } else if (k <= t[ls].sz) {
                u = ls;
            } else {
                return t[u].key;
            }
        }
        return false;
    }
    
    inline int get_rnk (int key) {
        find (key);
        return t[t[root].ch[0]].sz;
    }
}st;

int n, x, opt;

int main () {
//  freopen ("splay.in", "r", stdin);
    scanf ("%d", &n);
    st.Insert (+INF);
    st.Insert (-INF);
    for (int i = 1; i <= n; ++i) {
        scanf ("%d %d", &opt, &x);
        if (opt == 1) {
            st.Insert (x);
        } 
        if (opt == 2) {
            st.Delete (x);
        }
        if (opt == 3) {
            printf ("%d\n", st.get_rnk (x));
        }
        if (opt == 4) {
            printf ("%d\n", st.kth (x + 1));
        }
        if (opt == 5) {
            printf ("%d\n", st.t[st.Next (x, 0)].key);
        }
        if (opt == 6) {
            printf ("%d\n", st.t[st.Next (x, 1)].key);
        }
    }
} 

猜你喜欢

转载自www.cnblogs.com/maomao9173/p/10297014.html