Getting to know the line segment tree

Getting to know the line segment tree

The line segment tree is a binary search tree , similar to the interval tree , which divides an interval into some unit intervals, and each unit interval corresponds to a leaf node in the line segment tree.

Using the line segment tree can quickly find the number of times a certain node appears in several line segments, and the time complexity is O(logN). The unoptimized space complexity is 2N. In practical applications, 4N arrays are generally required to avoid crossing the boundary. Therefore, discretization is sometimes required to compress the space.

Question 1:
Now there are 100,000 positive integers, numbered from 1 to 100,000.

Now given an interval [L, R].

Find the sum of the interval L to R

Method 1: directly for(int i=L;i<=R;i++) to traverse 100,000 numbers and add them all up

Method 2: Simplify the calculation by obtaining the prefix sum, and the prefix sum array is B[100010], then the result is B[R]-B[L-1] is the result, it is not difficult to see that method 2 is faster than method 1

Topic 2:

Now there are 100000 positive integers numbered from 1 to 100000.

Now given an interval [L, R] and a positive integer k, c.

After adding c to the kth number, find the sum of the interval L to R

If you continue to use method one, its time complexity will not change.

But for method 2, after adding a number, its prefix and array will change. If k=10, then the prefix and all of the entire interval [10,100000] need to be modified, which will Greatly slows down calculations

From the two examples above, it can be seen that

Method 1: Summation is slow, but modification is fast

Method 2: Summation is fast, but summation is slow

So is there a way to take into account the advantages of these two methods? The summation and modification are fast. This is the line segment tree to be introduced in this article . The time complexity of inserting the number of line segments is logN

Segment tree division

The line segment tree is a binary tree. After an interval [L, R] is given, we continue to divide the interval evenly until L==R.

segment tree

How to define a segment tree

As can be seen from the figure, the line segment tree is composed of many intervals, and each interval records the 左端点sum of the interval 右端点and the sum of the values ​​​​in the interval, so we need to define a structure

struct node
{
	int l, r;
	int sum;
}tr[4*N];

The size of the array needs to be quadrupled, the reason is not proven, just remember it first

How to calculate the value of each interval?

bottom-up calculation

From the leaf nodes of the line segment tree (only its own nodes), for example, the interval [1,2] can be calculated by node[i].l+node[i].r(1+2).

Calculated from bottom to top.

void push_up(int u)
{
	tr[u].sum = tr[2 * u].sum + tr[2 * u + 1].sum;//2*u为左儿子,2*u+1为右儿子
}

How to build a line segment tree?

void build(int u, int l, int r)
{
	if (l == r) tr[u] = { l,r ,w[l]};//如果达到了叶子节点,就赋值
	else
	{
		tr[u] = { l,r };//没有到达叶子节点,就先记录下当前区间的左端点和右端点
		int mid = l + r >> 1;//将区间平分
		build(2 * u, l, mid);//递归左儿子
		build(2 * u + 1, mid + 1, r);//递归右儿子
		push_up(u);//回溯的时候依次通过左右儿子算得sum
	}
}

How to modify a value?

void modify(int u, int x, int v)
{
	if (tr[u].l == tr[u].r)//递归到了叶子节点的时候
	{
		tr[u].sum += v;
		return;
	}
	else
	{
		int mid = (tr[u].l + tr[u].r) / 2;
		if (x <= mid) modify(u * 2, x, v);//如果当前序列在左边,那么就递归左区间
		else modify(u * 2 + 1, x, v);//在右边就递归右区间

		push_up(u);//修改了之后,还要需要修改一些节点的值,重新自下而上计算
	}
}

How to find the sum of a certain interval?

interval sum

The intervals that need to be designed are [4], [5,6], [7,8], [9,10], [11].

int query(int u, int l, int r)
{
	//需要累加所有在这个范围内的区间
	if (l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
	//否则的话就需要递归计算
	int mid = (tr[u].l + tr[u].r) / 2;
	int sum = 0;
	if (mid >= l)  sum += query(u*2, l, r);//如果左区间和要求的区间有交集,那么递归左区间
	if (r >= mid + 1) sum += query(u * 2 + 1, l, r);//如果右区间和要求的区间有交集,那么递归右区间

	return sum;
}

Classic example:

example

AC code:

#include<iostream>
using namespace std;
const int N = 100010;
int n, m;
int w[N];//权值

//定义线段树节点
struct node
{
	int l, r;
	int sum;
}tr[4*N];//要开四倍大小

//向上累加
void push_up(int u)
{
	tr[u].sum = tr[2 * u].sum + tr[2 * u + 1].sum;
}

//建树
void build(int u, int l, int r)
{
	if (l == r) tr[u] = { l,r ,w[l]};//如果达到了叶子节点,就赋值
	else
	{
		tr[u] = { l,r };//没有到达叶子节点,就先记录下当前区间的左端点和右端点
		int mid = l + r >> 1;//将区间平分
		build(2 * u, l, mid);//递归左儿子
		build(2 * u + 1, mid + 1, r);//递归右儿子
		push_up(u);//回溯的时候依次通过左右儿子算得sum
	}
}

//区间查询
int query(int u, int l, int r)
{
	//需要累加所有在这个范围内的区间
	if (l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
	//否则的话就需要递归计算
	int mid = (tr[u].l + tr[u].r) / 2;
	int sum = 0;
	if (mid >= l)  sum += query(u*2, l, r);//如果左区间和要求的区间有交集,那么递归左区间
	if (r >= mid + 1) sum += query(u * 2 + 1, l, r);//如果右区间和要求的区间有交集,那么递归右区间

	return sum;
}

//修改
void modify(int u, int x, int v)
{
	if (tr[u].l == tr[u].r)//递归到了叶子节点的时候
	{
		tr[u].sum += v;
		return;
	}
	else
	{
		int mid = (tr[u].l + tr[u].r) / 2;
		if (x <= mid) modify(u * 2, x, v);//如果当前序列在左边,那么就递归左区间
		else modify(u * 2 + 1, x, v);//在右边就递归右区间

		push_up(u);//修改了之后,还要需要修改一些节点的值,重新自下而上计算
	}
}

int main(void)
{
	cin >> n >> m;
	for (int i = 1; i <= n; i++) scanf("%d", &w[i]);

	build(1, 1, n);

	while (m--)
	{
		int k, a, b;
		cin >> k >> a >> b;
		if (k == 0) cout << query(1, a, b) << endl;
		else
		{
			modify(1, a, b);
		}
	}
	return 0;
}

example

Without full AC code (too slow):

#include<iostream>
#include<algorithm>
using namespace std;
const int N = 100010;
int w[N];
int n, m;
struct node
{
	int l,r;
	int maxv;
}tr[N*4];

void push_up(int u)
{
	tr[u].maxv = max(tr[u * 2].maxv, tr[u * 2 + 1].maxv);
}

void build(int u, int l, int r)
{
	if (l == r)
	{
		tr[u] = { l,r,w[l] };
		return;
	}
	else
	{
		tr[u] = { l,r};
		int mid = (l + r) >> 1;
		build(u * 2, l, mid);
		build(u * 2 + 1, mid+1, r);
		push_up(u);
	}
}

int query(int u, int l, int r)
{
	if (tr[u].l >= l && tr[u].r <= r) return tr[u].maxv;
	int mid = (tr[u].l + tr[u].r) / 2;
	int maxv = -10000000;
	if (l <= mid) maxv = max(maxv, query(u * 2, l, r));
	if (r > mid + 1) maxv = max(maxv, query(u * 2 + 1, l, r));
	return maxv;

}

int main()
{
	int l, r;
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; ++i)   scanf("%d", &w[i]);
	build(1, 1, n);
	while (m--) {
		scanf("%d %d", &l, &r);
		printf("%d\n", query(1, l, r));
	}
	return 0;
}

Guess you like

Origin blog.csdn.net/AkieMo/article/details/129976475