线段树学习(单点更新+区间更新+区间查询)(C++模板)

一、线段树的用处

        在对一组连续的数据进行修改或者求和(求最值)操作时,线段树可以通过快速的修改子区间上的值来达成你的目标。



二、线段树是什么

        线段树是一种二叉搜索树,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。使用线段树可以快速的查找某一条线段对应的状态。

        看一副图来理解(图片魔改自百度百科):


    可见图中我们用一个节点1来储存一段[1,10]线段上的数据,将节点1对半拆开,可以得到节点2和节点3,他们分别储存的是[1,5]和[6,10]上的数据。以此类推,我们可以分出节点4、节点5……直到不可再分。



三、线段树的单点更新和查询(后附区间更新)

    这里采用的是结构体的方式建立线段树。

const int maxn = 500005 * 4;	//线段树范围要开4倍
struct Tree
{
	int l, r, sum, maxx;
};
Tree node[maxn];	//node[maxn]为线段树处理使用的数组
int a[maxn];		//a[maxn]表示读入的数据	

    结构体中的l,r表示的是该节点所覆盖的区间为[l,r]。sum表示的是该段区间上的数据总和,maxx表示该段区间上数据的最值。看到之前的那幅图,我们在将区间进行分割的时候,会出现大量的小区间,所以对于结构体的数组(即线段树)大小我们需要设定成条件所给的4倍。


    设置完了节点,那么我们怎么来建立一个我们所需要的线段树呢。再看这副图。

    

我们该怎么用代码实现连出两条线使得节点2和节点3成为节点1的子节点呢?

我们可以设当前的节点为i,那么对于图上的规律来讲,他左边的子节点编号就是i*2,右边的子节点编号就是i*2+1

知道了怎么表示子节点后,我们就要试图使用他们将叶子节点的值传递到下一层。

对于叶子节点来讲,他们的值应该是已知的,我们只需要在他们的父节点处进行更新就可以了,这是可以用递归实现的。

那么具体建树的代码就应该是这样的。

void update(int i)
{
	node[i].sum = node[i << 1].sum + node[(i << 1) | 1].sum;            //求和子节点
	node[i].maxx = max(node[i << 1].maxx, node[(i << 1) | 1].maxx);       //取子节点最值
}

void build(int i, int l, int r)
{
	node[i].l = l; node[i].r = r;
	if (l == r)            //到达了叶子节点直接赋值
	{
		node[i].maxx = a[l];
		node[i].sum = a[l];
		return;
	}
	int mid = (l + r) / 2;
	build(i << 1, l, mid);                //左节点建立
	build((i << 1) | 1, mid + 1, r);        //右节点建立
	update(i);
}

    建立了线段树后我们要对线段树上的数据进行修改与求和,该怎么操作呢?

    对于单点的修改我们找到所需要修改的点k对应的叶子节点,然后一路递归更新下去实际上在代码和build线段树是差不多的。

    让我们直接看代码来解释:

void add(int i, int k, int v)	        //当前更新的节点的编号为i(一般是以1为第一个编号)。
{					//k为需要更新的点的位置,v为修改的值的大小
	if (node[i].l == k&&node[i].r == k)        //左右端点均和k相等,说明找到了k所在的叶子节点
	{
		node[i].sum += v;
		node[i].maxx += v;
		return;    //找到了叶子节点就不需要在向下寻找了
	}
	int mid = (node[i].l + node[i].r) / 2;
	if (k <= mid) add(i << 1, k, v);
	else add((i << 1) | 1, k, v);            //寻找k所在的子区间
	update(i);        //递归更新
}

    使用这个add函数就可以实现对线段树的单点更新啦。比如使k点的值加上v,就是add(1,k,v)。

    求区间的最值代码实际上和求和是一样的,也是先找到对应区间所在的子节点,然后向下递归更新。

    求最值代码如下

    

int getmax(int i, int l, int r)
{
	if (node[i].l == l&&node[i].r == r)
		return node[i].maxx;
	int mid = (node[i].l + node[i].r) / 2;
	if (r <= mid) return getmax(i << 1, l, r);
	else if (l>mid) return getmax((i << 1) | 1, l, r);
	else return max(getmax(i << 1, l, mid), getmax((i << 1) | 1, mid + 1, r));
}

    以上。我们已经完成对于一个线段树的单点更新和查询。

   模板如下

const int maxn = 500005 * 4;	//线段树范围要开4倍
struct Tree
{
	int l, r, sum, maxx;
};
Tree node[maxn];		//node[maxn]为线段树处理数组
int a[maxn];			//a[maxn]为原数组
void update(int i)
{
	node[i].sum = node[i << 1].sum + node[(i << 1) | 1].sum;
	node[i].maxx = max(node[i << 1].maxx, node[(i << 1) | 1].maxx);
}
void build(int i, int l, int r)
{
	node[i].l = l; node[i].r = r;
	if (l == r)
	{
		node[i].maxx = a[l];
		node[i].sum = a[l];
		return;
	}
	int mid = (l + r) / 2;
	build(i << 1, l, mid);
	build((i << 1) | 1, mid + 1, r);
	update(i);
}
int getsum(int i, int l, int r)
{
	if (node[i].l == l&&node[i].r == r)
		return node[i].sum;
	int mid = (node[i].l + node[i].r) / 2;
	if (r <= mid) return getsum(i << 1, l, r);
	else if (l > mid) return getsum((i << 1) | 1, l, r);
	else return getsum(i << 1, l, mid) + getsum((i << 1) | 1, mid + 1, r);
}
int getmax(int i, int l, int r)
{
	if (node[i].l == l&&node[i].r == r)
		return node[i].maxx;
	int mid = (node[i].l + node[i].r) / 2;
	if (r <= mid) return getmax(i << 1, l, r);
	else if (l>mid) return getmax((i << 1) | 1, l, r);
	else return max(getmax(i << 1, l, mid), getmax((i << 1) | 1, mid + 1, r));
}
void add(int i, int k, int v)	        //当前更新的节点的编号为i(一般是1为初始编号,具体得看建立树时使用的第一个编号是什么)。
{								//k为需要更新的点的位置,v为修改的值的大小
	if (node[i].l == k&&node[i].r == k)        //左右端点均和k相等,说明找到了k所在的叶子节点
	{
		node[i].sum += v;
		node[i].maxx += v;
		return;    //找到了叶子节点就不需要在向下寻找了
	}
	int mid = (node[i].l + node[i].r) / 2;
	if (k <= mid) add(i << 1, k, v);
	else add((i << 1) | 1, k, v);
	update(i);
}

四、线段树的区间更新

    为什么需要把区间更新和单点更新区分开来呢?

    当我们面对给[a,b]范围上的数据都加上v,这一类的问题时,我们利用单点更新是怎么操作的呢?

    首先单点更新a+1,再更新a+2,再a+3……直到更新b。那么对于多个这样的询问,显然操作数是爆表的。所以我们需要一种巧妙的方法,降低我们更新的操作数。

    这里引入了一个标记数组,lazy[maxn<<2].

    由字面的意思,这个lazy数组就是一个给懒人使用标记。

    每当我们需要把一个区间[a,b]都加上v,现在我们其实并没有直接进入到线段树的对应区间的子区间去修改,而是先给这个区间做一个标记v,若这个区间有n个数据,当我们查询时候只需要读取区间原有的数据并且加上n*v。

    就我的理解lazy标记更像是维护了另一个树。

    简单的说就是,我们把向下的修改先储存起来,而对于每个查询我们在向上传递答案的时候加上这些修改的值。

    用代码来实现就是这样

void PushUp(int rt)
{
	tree[rt].sum = tree[rt << 1].sum + tree[rt << 1 | 1].sum;
}

void PushDown(int rt,int m)        //m表示的是rt对应的当前区间的长度
{
	if (lazy[rt])
	{
		lazy[rt << 1] += lazy[rt];           //延迟的值向左节点传递
		lazy[rt << 1 | 1] += lazy[rt];        //延迟的值向右节点传递
		tree[rt << 1].sum += lazy[rt] * (m - (m >> 1));   
		tree[rt << 1 | 1].sum += lazy[rt] * (m >> 1);
		lazy[rt] = 0;
	}
}

PushUp函数表示的是向上的更新,PushDown维护的是lazy标记延后的值。

明白了lazy的作用,就可以偷偷放出区间更新的模板啦。

const int N = 100005;
LL a[N];					//a[N]储存原数组
LL  lazy[N << 2];			//lazy用来记录该节点的每个数值应该加多少 
int n, q;
struct Tree
{
	int l, r;
	LL sum;
	int mid()
	{
		return (l + r) >> 1;
	}
}tree[N<<2];		

void PushUp(int rt)
{
	tree[rt].sum = tree[rt << 1].sum + tree[rt << 1 | 1].sum;
}

void PushDown(int rt,int m)
{
	if (lazy[rt])
	{
		lazy[rt << 1] += lazy[rt];
		lazy[rt << 1 | 1] += lazy[rt];
		tree[rt << 1].sum += lazy[rt] * (m - (m >> 1));
		tree[rt << 1 | 1].sum += lazy[rt] * (m >> 1);
		lazy[rt] = 0;
	}
}

void build(int l, int r, int rt)
{
	tree[rt].l = l;
	tree[rt].r = r;
	lazy[rt] = 0;
	if (l == r)
	{
		tree[rt].sum = a[l];
		return;
	}
	int m = tree[rt].mid();
	build(l, m, (rt << 1));
	build(m + 1, r, (rt << 1 | 1));
	PushUp(rt);
}

void update(LL c, int l, int r, int rt)
{
	if (tree[rt].l == l&&tree[rt].r==r)
	{ 
		lazy[rt] += c;
		tree[rt].sum += c*(r - l + 1);
		return;
	}
	if (tree[rt].l == tree[rt].r)return;
	int m = tree[rt].mid();
	PushDown(rt, tree[rt].r - tree[rt].l + 1);
	if (r <= m)update(c, l, r, rt << 1);
	else if (l > m)update(c, l, r, rt << 1 | 1);
	else 
	{
		update(c, l, m, rt << 1);
		update(c, m + 1, r, rt << 1 | 1);
	}
	PushUp(rt);
}

LL Query(int l, int r, int rt)
{
	if (l == tree[rt].l&&r == tree[rt].r)
	{
		return tree[rt].sum;
	}
	int m = tree[rt].mid();
	PushDown(rt, tree[rt].r - tree[rt].l + 1);
	LL res = 0;
	if (r <= m)res += Query(l, r, rt << 1);
	else if (l > m)res += Query(l, r, rt << 1 | 1);
	else
	{
		res += Query(l, m, rt << 1);
		res += Query(m + 1, r, rt << 1 | 1);
	}
	return res;
}


附上线段树模板题:

POJ3468A Simple Problem with Integers


以及AC代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<string>
#include<algorithm>
#include<vector>
#include<queue>
#include<set>
#include<map>
#include<stack>
#include<list>
using namespace std;
const int INF = 0x3f3f3f3f;
#define LL long long int 
long long  gcd(long long  a, long long  b) { return a == 0 ? b : gcd(b % a, a); }



const int N = 100005;
LL a[N];					
LL  lazy[N << 2];			
int n, q;
struct Tree
{
	int l, r;
	LL sum;
	int mid()
	{
		return (l + r) >> 1;
	}
}tree[N<<2];		

void PushUp(int rt)
{
	tree[rt].sum = tree[rt << 1].sum + tree[rt << 1 | 1].sum;
}

void PushDown(int rt,int m)
{
	if (lazy[rt])
	{
		lazy[rt << 1] += lazy[rt];
		lazy[rt << 1 | 1] += lazy[rt];
		tree[rt << 1].sum += lazy[rt] * (m - (m >> 1));
		tree[rt << 1 | 1].sum += lazy[rt] * (m >> 1);
		lazy[rt] = 0;
	}
}

void build(int l, int r, int rt)
{
	tree[rt].l = l;
	tree[rt].r = r;
	lazy[rt] = 0;
	if (l == r)
	{
		tree[rt].sum = a[l];
		return;
	}
	int m = tree[rt].mid();
	build(l, m, (rt << 1));
	build(m + 1, r, (rt << 1 | 1));
	PushUp(rt);
}

void update(LL c, int l, int r, int rt)
{
	if (tree[rt].l == l&&tree[rt].r==r)
	{ 
		lazy[rt] += c;
		tree[rt].sum += c*(r - l + 1);
		return;
	}
	if (tree[rt].l == tree[rt].r)return;
	int m = tree[rt].mid();
	PushDown(rt, tree[rt].r - tree[rt].l + 1);
	if (r <= m)update(c, l, r, rt << 1);
	else if (l > m)update(c, l, r, rt << 1 | 1);
	else 
	{
		update(c, l, m, rt << 1);
		update(c, m + 1, r, rt << 1 | 1);
	}
	PushUp(rt);
}

LL Query(int l, int r, int rt)
{
	if (l == tree[rt].l&&r == tree[rt].r)
	{
		return tree[rt].sum;
	}
	int m = tree[rt].mid();
	PushDown(rt, tree[rt].r - tree[rt].l + 1);
	LL res = 0;
	if (r <= m)res += Query(l, r, rt << 1);
	else if (l > m)res += Query(l, r, rt << 1 | 1);
	else
	{
		res += Query(l, m, rt << 1);
		res += Query(m + 1, r, rt << 1 | 1);
	}
	return res;
}

int main()
{
	while (scanf("%d%d", &n, &q) != EOF)
	{
		for (int i = 1; i <= n; i++)
			scanf("%lld", &a[i]);
		build(1, n, 1);
		char t;
		int a, b;
		LL c;
		while (q--)
		{
			getchar();
			scanf("%c", &t);
			if (t == 'Q')
			{
				scanf("%d %d", &a, &b);
				printf("%lld\n", Query(a, b, 1));
			}
			else if (t == 'C')
			{
				scanf("%d %d %lld", &a, &b, &c);
				update(c, a, b, 1);
			}
		}
	}
	getchar();
	getchar();
}

    

猜你喜欢

转载自blog.csdn.net/amovement/article/details/80615499