主席树简介及代码模板

主席树全称是可持久化权值线段树,通常用来解决区间第 k 小、区间众数的问题。

可持久化数据结构 (Persistent data structure) 总是可以保留每一个历史版本,并且支持操作的不可变特性 (immutable)。

权值线段树是记录每个权值出现次数的一种线段树。

为什么要使用主席树?

如果我们要求区间第 k 小的数,如果暴力一点,可以每插入一次开一棵线段树,但是空间一定会爆掉。所以要使用主席树,在原有空间的基础上,进行插入。

主席树的主要思想

保存每次插入操作时的历史版本,以便查询区间第 k 小。

主席树的基本操作:

1、build,建树

2、insert,插入值

3、query,查询区间的第 k 小

主席树的基本结构:

struct Node {
    
    
	int l, r, sum;
} tr[N * 40];

int n, q, m, idx; // idx 为每个数字的下标
int root[N], a[N], b[N]; // root 记录每个数字的根节点,a 为原数组,b 为离散化后的数组

建树操作:

以 l 为左端点,以 r 为右端点,建树

int build(int l, int r)
{
    
    
	int p = ++idx; // 该子树的根节点 
	if (l == r) return p;
	int mid = l + r >> 1;
	tr[p].l = build(l, mid);
	tr[p].r = build(mid + 1, r);
	
	return p;
}

插入操作:

在 pre 根节点的基础上,插入 x

int insert(int pre, int l, int r, int x)
{
    
    
	int p = ++idx;
	tr[p] = tr[pre];
	if (l == r) {
    
    
		tr[p].sum++;
		return p;
	}
	int mid = l + r >> 1;
	if (x <= mid) 
		tr[p].l = insert(tr[p].l, l, mid, x);
	else
		tr[p].r = insert(tr[p].r, mid + 1, r, x);
	tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum;
	
	return p;
}

查询操作 / 求区间第 k 小:

如果我们要求区间 [1, r] 的第 k 小,只需要找到插入 r 时的根节点版本,再利用权值线段树即可。

同理,如果要求的是区间 [l, r] 的第 k 小,可以利用前缀和的思想,用 [1, r] 的信息减去 [1, l - 1] 的信息即可。

用 q、p 两节点作差,得到左儿子的信息,进而判断第 k 小的数在哪个儿子中

int query(int p, int q, int l, int r, int k)
{
    
    
	if (l == r) return r;
	int x = tr[tr[q].l].sum - tr[tr[p].l].sum; // 相减得到左儿子的信息 
	int mid = l + r >> 1;
	if (k <= x) // 第 K 小的数在左儿子中 
		return query(tr[p].l, tr[q].l, l, mid, k);
	else // 第 K 小的数在右儿子中 
		return query(tr[p].r, tr[q].r, mid + 1, r, k - x);
}

基本用法:

1、主席树求区间第 k 小(代码带有离散化):

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5;

struct Node {
    
    
	int l, r, sum;
} tr[N * 40];

int n, q, m, idx;
int root[N], a[N], b[N];

int build(int l, int r)
{
    
    
	int p = ++idx; // 该子树的根节点 
	if (l == r) return p;
	int mid = l + r >> 1;
	tr[p].l = build(l, mid);
	tr[p].r = build(mid + 1, r);
	
	return p;
}

int insert(int pre, int l, int r, int x)
{
    
    
	int p = ++idx;
	tr[p] = tr[pre];
	if (l == r) {
    
    
		tr[p].sum++;
		return p;
	}
	int mid = l + r >> 1;
	if (x <= mid) 
		tr[p].l = insert(tr[p].l, l, mid, x);
	else
		tr[p].r = insert(tr[p].r, mid + 1, r, x);
	tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum;
	
	return p;
}

int query(int p, int q, int l, int r, int k)
{
    
    
	if (l == r) return r;
	int x = tr[tr[q].l].sum - tr[tr[p].l].sum; // 相减得到左儿子的信息 
	int mid = l + r >> 1;
	if (k <= x) // 第 K 小的数在左儿子中 
		return query(tr[p].l, tr[q].l, l, mid, k);
	else // 第 K 小的数在右儿子中 
		return query(tr[p].r, tr[q].r, mid + 1, r, k - x);
}

int main(void)
{
    
    
	scanf("%d%d", &n, &q);
	for (int i = 1; i <= n; i++) {
    
    
		scanf("%d", &a[i]);
		b[i] = a[i];
	}
	sort(b + 1, b + n + 1);
	m = unique(b + 1, b + n + 1) - b - 1;
	
	root[0] = build(1, m);
	for (int i = 1; i <= n; i++) {
    
    
		int t = lower_bound(b + 1, b + m + 1, a[i]) - b;
		root[i] = insert(root[i - 1], 1, m, t);
	}
	
	int l, r, k;
	while (q--) {
    
    
		scanf("%d%d%d", &l, &r, &k);
		int t = query(root[l - 1], root[r], 1, m, k);
		printf("%d\n", b[t]);
	}
	
	return 0;
}

2、主席树求区间众数(区间内出现次数大于(l - r + 1) / 2的数):

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;

struct Node {
    
    
	int l, r, sum;
} tr[N * 40];

int n, q, m, idx;
int root[N], a[N], b[N];

int build(int l, int r)
{
    
    
	int p = ++idx; // 该子树的根节点 
	if (l == r) return p;
	int mid = l + r >> 1;
	tr[p].l = build(l, mid);
	tr[p].r = build(mid + 1, r);
	
	return p;
}

int insert(int pre, int l, int r, int x)
{
    
    
	int p = ++idx;
	tr[p] = tr[pre];
	if (l == r) {
    
    
		tr[p].sum++;
		return p;
	}
	int mid = l + r >> 1;
	if (x <= mid) 
		tr[p].l = insert(tr[p].l, l, mid, x);
	else
		tr[p].r = insert(tr[p].r, mid + 1, r, x);
	tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum;
	
	return p;
}

int query(int p, int q, int l, int r, int k)
{
    
    
	if (l == r) return r;
	int x = tr[tr[q].l].sum - tr[tr[p].l].sum; // 相减得到左儿子的信息 
	int y = tr[tr[q].r].sum - tr[tr[p].r].sum; // 右儿子的信息 
	int mid = l + r >> 1;
	if (x > k) // 众数在左儿子中 
		return query(tr[p].l, tr[q].l, l, mid, k);
	else if (y > k) // 众数在右儿子中 
		return query(tr[p].r, tr[q].r, mid + 1, r, k);
	else // 不存在众数
		return 0;
}

int main(void)
{
    
    
	scanf("%d%d", &n, &q);
	for (int i = 1; i <= n; i++) {
    
    
		scanf("%d", &a[i]);
		b[i] = a[i];
	}
	sort(b + 1, b + n + 1);
	m = unique(b + 1, b + n + 1) - b - 1;
	
	root[0] = build(1, m);
	for (int i = 1; i <= n; i++) {
    
    
		int t = lower_bound(b + 1, b + m + 1, a[i]) - b;
		root[i] = insert(root[i - 1], 1, m, t);
	}
	
	int l, r, k;
	while (q--) {
    
    
		scanf("%d%d", &l, &r);
		int t = query(root[l - 1], root[r], 1, m, (r - l + 1) / 2);
		printf("%d\n", b[t]);
	}
	
	return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_43772166/article/details/109184508