线段树
为什么使用线段树
- 区间染色
- 有一面墙,长度为n,每次选择一段墙进行染色
- M次操作后,我们可以可见多少种颜色
- M次操作后,在[i,j]区间能看见多少种颜色
涉及的操作
- 染色操作 (更新区间)
- 查询操作 (查询区间)
数组 | 线段树 | |
---|---|---|
染色操作 | O(N) | O(logN) |
查询操作 | O(N) | O(logN) |
什么是线段树
-
以求和为例,每个节点就是存储的每个区间的和
-
线段树是平衡二叉树,但不一定是满二叉树
-
堆也是平衡二叉树
-
可以把线段树看做满二叉树 : 将没有区间的看做[]
-
若区间有n个元素,数组需要
4n
的空间节点 -
使用数组实现:
我们的线段树不考虑添加元素,即区间固定,使用4n的静态空间即可
线段树的基本操作
线段树的基本操作主要包括构造线段树,区间查询和区间修改
1. 构造线段树
构造线段树是一个递归的过程
C++实现:
// 返回完全二叉树数组表示中,一个索引表示的节点的左孩子的索引
int leftChild(int index) {
return 2 * index + 1;
}
// 返回完全二叉树数组表示中,一个索引表示的节点的右孩子的索引
int rightChild(int index) {
return 2 * index + 2;
}
// 在treeIdex位置创建区间表示[l...r]的线段树
void buildSegmentTree(int treeIndex, int l, int r) {
if (l == r) {
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 递归调用
int mid = l+(r-l) / 2; // (l+r)/2 会有溢出问题
buildSegmentTree(leftTreeIndex, l, mid);
buildSegmentTree(rightTreeIndex, mid + 1, r);
// 给tree[]赋值,与业务相关,以求和为例
tree[treeIndex] = tree[leftTreeIndex] + tree[rightTreeIndex];
}
2.区间查询
区间查询指的是用户指定一个区间,获得这个区间的相关信息,如区间的最大值,最小值,和等.
查询的C++代码如下:
- 查询一下是否能够找到对应的区间
- 若没有向下继续遍历
// 线段树查询
T query(int queryL, int queryR) {
if (queryR >= 0 && queryR < data.size()
&& queryL >= 0 && queryL < data.size()
&& queryR> queryL) {
return query(0,0,data.size()-1,queryL,queryR);
}
}
T query(int treeIndex, int l, int r, int queryL, int queryR) {
if (l == queryL&&r == queryR) {
return tree[treeIndex];
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l) / 2;
if (queryL > mid) {
return query(rightTreeIndex, mid + 1, r, queryL, queryR);
} else if (queryR <= mid) {
return query(leftTreeIndex, l, mid, queryL, queryR);
}
// [queryL..mid] + [mid+1,queryR]
T leftRes = query(leftTreeIndex, l, mid, queryL, mid);
T rightRes = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
return leftRes + rightRes;
}
3.区间更新
3.1 更新某个叶节点:
// 将index位置的值,更新为e
void set(int index, T val) {
if (index >= 0 && index < data.size()) {
data[index] = e;
// 更新线段树,叶子节点
set(0, 0, data.size() - 1, index, val);
}
}
void set(int treeIndex, int l, int r, int index, T val) {
if (l == r) {
tree[treeIndex] = val;
return;
}
// 去找线段树中叶子节点位置的索引在哪里
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l) / 2;
if (index <= mid) {
set(leftTreeIndex, l, mid, index, val);
} else {
set(rightTreeIndex, mid + 1, r, index, val);
}
// 更新过程
tree[treeIndex] = tree[leftTreeIndex] + tree[rightTreeIndex];
}
3.2 更新某个区间:
采用懒惰更新,延迟标记
待续
LeetCode 307
class NumArray {
public:
vector<int> data;
vector<int> tree;
int size;
// 1. 构建线段树
void buildSegmentTree(int treeIdx, int l, int r) {
if (l == r) {
tree[treeIdx] = data[l];
return;
}
int mid = l + (r - l) / 2;
int leftChildTreeIdx = treeIdx * 2+1;
int rightChildTreeIdx = treeIdx * 2 + 2;
buildSegmentTree(leftChildTreeIdx, l, mid);
buildSegmentTree(rightChildTreeIdx, mid + 1, r);
tree[treeIdx] = tree[leftChildTreeIdx] + tree[rightChildTreeIdx];
}
// 2.查询线段树
int query(int treeIdx, int l, int r, int queryL, int queryR) {
if (l == queryL&&r == queryR) {
return tree[treeIdx];
}
int mid = l + (r - l) / 2;
int leftChildTreeIdx = treeIdx * 2+1;
int rightChildTreeIdx = treeIdx * 2 + 2;
if (queryR <= mid) {
return query(leftChildTreeIdx, l, mid, queryL, queryR);
}
else if (queryL > mid) {
return query(rightChildTreeIdx, mid+1, r, queryL, queryR);
}
int leftResult = query(leftChildTreeIdx, l, mid, queryL, mid);
int rightResult = query(rightChildTreeIdx, mid + 1, r, mid+1, queryR);
return leftResult + rightResult;
}
int query(int queryL, int queryR) {
if (queryL >= 0 && queryL < size
&& queryR >= 0 && queryR < size
&& queryR >= queryL) {
return query(0, 0, size - 1, queryL, queryR);
}
return 0;
}
// 3.设置某个单个节点的值
void set(int treeIdx, int l, int r, int index, int val) {
if (l == r) {
tree[treeIdx] = val;
return;
}
int mid = l + (r - l) / 2;
int leftChildTreeIdx = treeIdx * 2+1;
int rightChildTreeIdx = treeIdx * 2 + 2;
if (index <= mid) {
set(leftChildTreeIdx, l, mid, index, val);
}
else {
set(rightChildTreeIdx, mid + 1,r, index, val);
}
tree[treeIdx] = tree[leftChildTreeIdx] + tree[rightChildTreeIdx];
}
void set(int index, int val) {
if (index >= 0 && index <size) {
data[index] = val;
set(0, 0, size - 1, index, val);
}
}
// 下面是题目的:
NumArray(vector<int> nums) {
// 1. 先拷贝整个数组
if (nums.size() == 0) {
return;
}
size = nums.size();
data = vector<int>(nums.begin(), nums.end());
// 2. 创建这个线段树
tree.resize(4 * size);
buildSegmentTree(0, 0, size - 1);
}
void update(int i, int val) {
set(i, val);
}
int sumRange(int i, int j) {
return query(i, j);
}
};