主席树全称是可持久化权值线段树,通常用来解决区间第 k 小、区间众数的问题。
可持久化数据结构 (Persistent data structure) 总是可以保留每一个历史版本,并且支持操作的不可变特性 (immutable)。
权值线段树是记录每个权值出现次数的一种线段树。
为什么要使用主席树?
如果我们要求区间第 k 小的数,如果暴力一点,可以每插入一次开一棵线段树,但是空间一定会爆掉。所以要使用主席树,在原有空间的基础上,进行插入。
主席树的主要思想:
保存每次插入操作时的历史版本,以便查询区间第 k 小。
主席树的基本操作:
1、build,建树
2、insert,插入值
3、query,查询区间的第 k 小
主席树的基本结构:
struct Node {
int l, r, sum;
} tr[N * 40];
int n, q, m, idx; // idx 为每个数字的下标
int root[N], a[N], b[N]; // root 记录每个数字的根节点,a 为原数组,b 为离散化后的数组
建树操作:
以 l 为左端点,以 r 为右端点,建树
int build(int l, int r)
{
int p = ++idx; // 该子树的根节点
if (l == r) return p;
int mid = l + r >> 1;
tr[p].l = build(l, mid);
tr[p].r = build(mid + 1, r);
return p;
}
插入操作:
在 pre 根节点的基础上,插入 x
int insert(int pre, int l, int r, int x)
{
int p = ++idx;
tr[p] = tr[pre];
if (l == r) {
tr[p].sum++;
return p;
}
int mid = l + r >> 1;
if (x <= mid)
tr[p].l = insert(tr[p].l, l, mid, x);
else
tr[p].r = insert(tr[p].r, mid + 1, r, x);
tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum;
return p;
}
查询操作 / 求区间第 k 小:
如果我们要求区间 [1, r] 的第 k 小,只需要找到插入 r 时的根节点版本,再利用权值线段树即可。
同理,如果要求的是区间 [l, r] 的第 k 小,可以利用前缀和的思想,用 [1, r] 的信息减去 [1, l - 1] 的信息即可。
用 q、p 两节点作差,得到左儿子的信息,进而判断第 k 小的数在哪个儿子中
int query(int p, int q, int l, int r, int k)
{
if (l == r) return r;
int x = tr[tr[q].l].sum - tr[tr[p].l].sum; // 相减得到左儿子的信息
int mid = l + r >> 1;
if (k <= x) // 第 K 小的数在左儿子中
return query(tr[p].l, tr[q].l, l, mid, k);
else // 第 K 小的数在右儿子中
return query(tr[p].r, tr[q].r, mid + 1, r, k - x);
}
基本用法:
1、主席树求区间第 k 小(代码带有离散化):
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5;
struct Node {
int l, r, sum;
} tr[N * 40];
int n, q, m, idx;
int root[N], a[N], b[N];
int build(int l, int r)
{
int p = ++idx; // 该子树的根节点
if (l == r) return p;
int mid = l + r >> 1;
tr[p].l = build(l, mid);
tr[p].r = build(mid + 1, r);
return p;
}
int insert(int pre, int l, int r, int x)
{
int p = ++idx;
tr[p] = tr[pre];
if (l == r) {
tr[p].sum++;
return p;
}
int mid = l + r >> 1;
if (x <= mid)
tr[p].l = insert(tr[p].l, l, mid, x);
else
tr[p].r = insert(tr[p].r, mid + 1, r, x);
tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum;
return p;
}
int query(int p, int q, int l, int r, int k)
{
if (l == r) return r;
int x = tr[tr[q].l].sum - tr[tr[p].l].sum; // 相减得到左儿子的信息
int mid = l + r >> 1;
if (k <= x) // 第 K 小的数在左儿子中
return query(tr[p].l, tr[q].l, l, mid, k);
else // 第 K 小的数在右儿子中
return query(tr[p].r, tr[q].r, mid + 1, r, k - x);
}
int main(void)
{
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
b[i] = a[i];
}
sort(b + 1, b + n + 1);
m = unique(b + 1, b + n + 1) - b - 1;
root[0] = build(1, m);
for (int i = 1; i <= n; i++) {
int t = lower_bound(b + 1, b + m + 1, a[i]) - b;
root[i] = insert(root[i - 1], 1, m, t);
}
int l, r, k;
while (q--) {
scanf("%d%d%d", &l, &r, &k);
int t = query(root[l - 1], root[r], 1, m, k);
printf("%d\n", b[t]);
}
return 0;
}
2、主席树求区间众数(区间内出现次数大于(l - r + 1) / 2的数):
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
struct Node {
int l, r, sum;
} tr[N * 40];
int n, q, m, idx;
int root[N], a[N], b[N];
int build(int l, int r)
{
int p = ++idx; // 该子树的根节点
if (l == r) return p;
int mid = l + r >> 1;
tr[p].l = build(l, mid);
tr[p].r = build(mid + 1, r);
return p;
}
int insert(int pre, int l, int r, int x)
{
int p = ++idx;
tr[p] = tr[pre];
if (l == r) {
tr[p].sum++;
return p;
}
int mid = l + r >> 1;
if (x <= mid)
tr[p].l = insert(tr[p].l, l, mid, x);
else
tr[p].r = insert(tr[p].r, mid + 1, r, x);
tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum;
return p;
}
int query(int p, int q, int l, int r, int k)
{
if (l == r) return r;
int x = tr[tr[q].l].sum - tr[tr[p].l].sum; // 相减得到左儿子的信息
int y = tr[tr[q].r].sum - tr[tr[p].r].sum; // 右儿子的信息
int mid = l + r >> 1;
if (x > k) // 众数在左儿子中
return query(tr[p].l, tr[q].l, l, mid, k);
else if (y > k) // 众数在右儿子中
return query(tr[p].r, tr[q].r, mid + 1, r, k);
else // 不存在众数
return 0;
}
int main(void)
{
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
b[i] = a[i];
}
sort(b + 1, b + n + 1);
m = unique(b + 1, b + n + 1) - b - 1;
root[0] = build(1, m);
for (int i = 1; i <= n; i++) {
int t = lower_bound(b + 1, b + m + 1, a[i]) - b;
root[i] = insert(root[i - 1], 1, m, t);
}
int l, r, k;
while (q--) {
scanf("%d%d", &l, &r);
int t = query(root[l - 1], root[r], 1, m, (r - l + 1) / 2);
printf("%d\n", b[t]);
}
return 0;
}