浅谈伸展树Splay

普通平衡树

Description

设计一种数据结构,支持插入元素,删除元素,查询值为val的元素的排名,查询排名为rnk的值,查询x的前驱、后驱

Solution

Splay的基本操作,熟悉一下Splay,这些操作事实上与Treap也能解决。
为了实现Splay,我们有如下定义及实现方法。
1.定义结构体Splay,成员同Treap的定义
2.定义update函数用于维护节点的sum值
直接加法运算即可
3.定义connect函数用于建立父子之间的关系
直接赋值即可
4.定义rotate函数用于Splay的旋转
以左旋为例,假设旋转的节点为x,他的父亲为y,他的右子树为B,他的祖父为z,那么我们令B的父亲为y,y的父亲为x,x的父亲为z
5.定义splay函数用于实现平衡树的伸展操作
假设当前节点为x,需要伸展到的节点为to,假定x的父亲、祖父为y,z,那么假设x,y都是其父亲的左/ 右儿子,那么我们旋转y,x,否则我们旋转两次x
6.定义insert函数用于实现元素的插入
首先根据BST的性质找到插入的位置,如果这个位置有节点那么cnt++,否则新建节点并赋值
7.定义find函数用于找到某个值的节点的编号并且将这个节点伸展到树根
根据BST的性质找到位置并调用splay函数实现
8.定义calc函数用于计算并返回排名为x的数
根据BST的性质,假定我们要查询排名为x的元素,那么假设x小于等于左子树的大小那么进入左子树,如果x大于左子树+该节点重复的次数那么 进入右子树,否则返回当前节点
9.定义query函数用于计算并返回某个元素的前驱后驱的编号
首先调用find()使得目标节点伸展到树根上,如果目标节点的值恰好符合题意那么直接返回,否则利用BST的性质找到答案
10.定义del函数用于删除某个元素
假设删除的元素值为x,那么我们找到他的前后驱,执行两次伸展,直接删除即可完成

Code

  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 const int INF = 2147483647;
  4 inline int read() {
  5     int ret = 0, op = 1;
  6     char c = getchar();
  7     while (!isdigit(c)) {
  8         if (c == '-') op = -1; 
  9         c = getchar();
 10     }
 11     while (isdigit(c)) {
 12         ret = ret * 10 + c - '0';
 13         c = getchar();
 14     }
 15     return ret * op;
 16 }
 17 struct Splay {
 18     int ch[2];
 19     int cnt, sum;
 20     int val, fa;
 21 } a[100010];
 22 int tot, root;
 23 void update(int now) {
 24     a[now].sum = a[a[now].ch[0]].sum + a[a[now].ch[1]].sum + a[now].cnt;
 25 }
 26 void connect(int x, int fa, int op) {
 27     a[x].fa = fa;
 28     a[fa].ch[op] = x;
 29 }
 30 void rotate(int x) {
 31     int y = a[x].fa;
 32     int z = a[y].fa;
 33     int xson = a[y].ch[1] == x ? 1 : 0;
 34     int yson = a[z].ch[1] == y ? 1 : 0;
 35     int B = a[x].ch[xson ^ 1];
 36     connect(B, y, xson); connect(y, x, xson ^ 1); connect(x, z, yson);
 37     update(y); update(x);
 38 }
 39 void splay(int from, int to) {
 40     while (a[from].fa != to) {
 41         int y = a[from].fa;
 42         int z = a[y].fa;
 43         if (z != to)
 44             (a[y].ch[0] == from) ^ (a[z].ch[0] == y) ? update(from) : update(y);
 45         rotate(from);
 46     }
 47     if (to == 0) root = from; 
 48 }
 49 void insert(int val) {
 50     int now = root, fa = 0;
 51     while (now && a[now].val != val) {
 52         fa = now;
 53         now = a[now].ch[val > a[now].val];
 54     }
 55     if(now) {
 56         a[now].cnt++;
 57     }
 58     else {
 59         a[now = ++tot].val = val;
 60         a[tot].sum = a[tot].cnt = 1;
 61         a[tot].fa = fa;
 62         a[tot].ch[0] = a[tot].ch[1] = 0;
 63         if (fa) a[fa].ch[val > a[fa].val] = tot;
 64     }
 65     splay(now, 0);
 66 }
 67 void find(int x) {
 68     int now = root;
 69     if (now == 0) return ;
 70     while (a[now].val != x && a[now].ch[a[now].val < x]) now = a[now].ch[a[now].val < x];
 71     splay(now, 0);
 72 }
 73 int calc(int x) {
 74     int now = root;
 75     if (a[now].sum < x) return 0;
 76     while (1) {
 77         int y = a[now].ch[0];
 78         if (x > a[y].sum + a[now].cnt) {
 79             x -= a[y].sum + a[now].cnt;
 80             now = a[now].ch[1];
 81         }
 82         else if (x <= a[y].sum) now = y;
 83         else return a[now].val;
 84     } 
 85 }
 86 int query(int x, int op) {
 87     find(x);
 88     int now = root;
 89     if ((op && a[now].val > x) || (a[now].val < x && !op)) return now;
 90     now = a[now].ch[op];
 91     while (a[now].ch[op ^ 1]) now = a[now].ch[op ^ 1];
 92     return now;
 93 }
 94 void del(int x) {
 95     int pre = query(x, 0);
 96     int nxt = query(x, 1);
 97     splay(pre, 0); splay(nxt, pre);
 98     int now = a[nxt].ch[0];
 99     if (a[now].cnt > 1) {
100         a[now].cnt--;
101         splay(now, 0);
102         return ;
103     }
104     else a[nxt].ch[0] = 0;
105 }
106 int main() {
107     insert(-INF);
108     insert(INF);
109     int m = read();
110     while (m--) {
111         int op = read(), x = read();
112         if (op == 1) {
113             insert(x);
114         }
115         else if (op == 2) {
116             del(x);
117         }
118         else if (op == 3) {
119             find(x);
120             printf("%d\n", a[a[root].ch[0]].sum);
121         }
122         else if (op == 4) {
123             printf("%d\n", calc(x + 1));
124         }
125         else if (op == 5) {
126             printf("%d\n", a[query(x, 0)].val);
127         }
128         else {
129             printf("%d\n", a[query(x, 1)].val);
130         }
131     } 
132     return 0;
133 }
AC Code

文艺平衡树

Description

 写一种数据结构,维护一个序列,并支持区间翻转

Solution

Splay的经典操作:维护区间翻转
对于区间翻转这种操作,由于原序列不能排序,所以我们不能建立一棵权值树,所以我们按照节点的编号建立一棵平衡树。
相关函数的定义如下:
1.定义update函数用于维护节点的sum值
同上
2.定义splay函数用于实现平衡树的伸展操作
同上
3.定义find函数用于找到某个值的节点的编号
根据BST的性质找到位置即可
4.定义rotate函数用于用于Splay的旋转
同上
5.定义build函数用于建立平衡树
确切的讲,我们仿照线段树的建树方式,首先建立当前节点,然后递归建立其左右儿子,然后调用update()维护信息即可
6.定义reverse函数用于实现区间翻转
假定我们翻转的区间为[l,r],那么我们调用splay()将l-1伸展到根节点,再调用一次splay()将r+1伸展到根节点的右儿子,这样我们只需要在根节点的右儿子的左儿子打一个标记即可。
7.定义pushdown函数用于下放标记
类比线段树,每一次翻转操作我们都会在相应的区间打标记,下放标记时将当前节点的标记清空,同时交换两个儿子,并且更新儿子的标记即可
8.定义connect函数用于建立父子之间的关系
同上
9.定义dfs函数用于输出最后的答案
根据BST的性质,我们对平衡树进行一次中序遍历即可输出最终的序列

Code

  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 const int INF = 2147483647;
  4 inline int read() {
  5     int ret = 0, op = 1;
  6     char c = getchar();
  7     while (!isdigit(c)) {
  8         if (c == '-') op = -1; 
  9         c = getchar();
 10     }
 11     while (isdigit(c)) {
 12         ret = ret * 10 + c - '0';
 13         c = getchar();
 14     }
 15     return ret * op;
 16 }
 17 int n, m, in[100010], root, tot;
 18 struct Splay {
 19     int val, sum, fa, ch[2], tag, cnt;
 20 } a[100010];
 21 void update(int now) {
 22     if (!now) return ;
 23     a[now].sum = a[now].cnt;
 24     if (a[now].ch[0]) a[now].sum += a[a[now].ch[0]].sum;
 25     if (a[now].ch[1]) a[now].sum += a[a[now].ch[1]].sum;
 26 }
 27 void pushdown(int now) {
 28     if (now && a[now].tag) {
 29         a[a[now].ch[0]].tag ^= 1;
 30         a[a[now].ch[1]].tag ^= 1;
 31         swap(a[now].ch[1], a[now].ch[0]);
 32         a[now].tag = 0;
 33     }
 34 }
 35 void connect(int x, int fa, int op) {
 36     a[fa].ch[op] = x;
 37     a[x].fa = fa;
 38 }
 39 void rotate(int x) {
 40     int y = a[x].fa;
 41     int z = a[y].fa;
 42     pushdown(x);
 43     pushdown(y);
 44     int xson = a[y].ch[1] == x ? 1 : 0;
 45     int yson = a[z].ch[1] == y ? 1 : 0;
 46     int B = a[x].ch[xson ^ 1];
 47     connect(B, y, xson); connect(y, x, xson ^ 1); connect(x, z, yson);
 48     update(y), update(x);
 49 }
 50 void splay(int from, int to) {
 51     while (a[from].fa != to) {
 52         int y = a[from].fa;
 53         int z = a[y].fa;
 54         if (z != to) (a[y].ch[0] == from) ^ (a[z].ch[0] == y) ? rotate(from) : rotate(y);
 55         rotate(from); 
 56     }
 57     if (to == 0) root = from;
 58 }
 59 int build(int fa, int l, int r) {
 60     if (l > r) return 0;
 61     int mid = l + r >> 1;
 62     int now = ++tot;
 63     a[now].val = in[mid];
 64     a[now].cnt++;
 65     a[now].fa = fa;
 66     a[now].sum++;
 67     a[now].ch[0] = 0;
 68     a[now].ch[1] = 0;    
 69     a[now].ch[0] = build(now, l, mid - 1);
 70     a[now].ch[1] = build(now, mid + 1, r);
 71     update(now);
 72     return now;
 73 }
 74 int find(int x) {
 75     int now = root;
 76     while (1) {
 77         pushdown(now);
 78         if (x <= a[a[now].ch[0]].sum) now = a[now].ch[0];
 79         else {
 80             x -= a[a[now].ch[0]].sum + 1;
 81             if (!x) return now;
 82             now = a[now].ch[1];
 83         }
 84     }
 85 }
 86 void reverse(int l, int r) {
 87     l--, r++;
 88     l = find(l);
 89     r = find(r);
 90     splay(l, 0);
 91     splay(r, l);
 92     int now = a[root].ch[1];
 93     now = a[now].ch[0];
 94     a[now].tag ^= 1;
 95 }
 96 void dfs(int now) {
 97     pushdown(now);
 98     if (a[now].ch[0]) dfs(a[now].ch[0]);
 99     if (a[now].val != INF && a[now].val != -INF) printf("%d ", a[now].val);
100     if (a[now].ch[1]) dfs(a[now].ch[1]);    
101 }
102 int main() {
103     n = read(); m = read();
104     in[1] = -INF; in[n + 2] = INF;
105     for (register int i = 1; i <= n; ++i) in[i + 1] = i;
106     root = build(0, 1, n + 2);
107     for (register int i = 1; i <= m; ++i) {
108         int x = read() + 1, y = read() + 1;
109         reverse(x, y);
110     }
111     dfs(root);
112     return 0;
113 }
AC Code

猜你喜欢

转载自www.cnblogs.com/shl-blog/p/11267488.html