竞赛线段树入门

说句实话,作为一个刚入坑的OIER,自认为天赋算非常差了,所以说开个博客记录一下自己学的东西,如果说能把刚学到的东西用最简单的方式讲解一下。如果可以让一些同样可能相对比较笨拙一些的人清晰地明白(dalao请无视这句话),那就算是我这个阶段的内容掌握的还不错吧。第一篇博客的开头就到这里,现在我们就开始来对线段树的基础做一些了解。

我们先来看一道非常简单的题目,给出一串长度为n的数列,并给出N次询问,每次询问有两种操作,modify和query,modify是对于数列中的某个数作更改,query为数列中区间[l,r]中的最大值、最小值和这个区间内所有数的和,该怎么做?
最简单的方法是暴力查询,修改操作的复杂度是O(1)的,而查询操作的话,从数列中的l开始查找,并且每读入一个数就更新一便最值和区间和,这样的话时间复杂度是O(Nn)的,正常来说出题人不脑抽的话,随便卡一下这道题就挂了。

同样的问题有区间修改,单点查询;区间修改、区间查询;单点修改、单点查询等等等等,但是如果使用暴力算法的话,时间复杂度都一样,而且都非常高。
所以在这种题上,我们就需要用到线段树的有关知识了(当然,树状数组和ST表也是可行的,但是线段树的应用范围和拓展空间比这两种数据结构多得多,因为在dfs序、树链剖分等其它数据结构里线段树都可以有多方面的应用。以后如果有空的话,我也会把这些内容都补充一下)。
那么线段树到底是什么东西呢?其实线段树就是一种二分搜索树,它把一个线性的数列结构拆分成许多个小的单元区间,然后每个节点内储存的是对应单元区间的信息,从而实现快速查询、修改等各种操作。我给大家一张图来演示一下吧

就拿图中的1~10举例子,我们可以用二分的方式,把1,10的区间划分成很多小区间(大家可以看到叶子节点上的区间都只有一个值,储存的便是数列上对应的单点的值),而如果我们要查询[2,5]的区间上的信息,我们就可以把[2,5]拆分成[2,2]+[3,3]+[4,5],然后将这三个区间需要查找的信息合并,就可以轻松得到最后的解。
这里需要注意,虽然图上的每个节点都是区间,但是实际意义上我们的数组/指针是不会去储存区间本身的,它们只会储存自己区间的信息,如这个区间的最大值、最小值等等等等,所以在实际进行查询等一系列操作的时候,我们都是用变量来划分区间,然后找到该区间对应的节点来进行操作的。

在这里,我们可以用二叉树的基本性质,用层序遍历来对于节点进行标号,对于一个序号为o的节点,左儿子编号是o*2,右儿子的编号则是o*2+1,不知道的同学们见以下这张图:

对于8号节点,也就是[1,2]这个区间来说,它的左儿子编号为16,右儿子编号为17(这些都是二叉树的基本性质,要证明的话非常简单,我就不加以赘述了)。
所以说,对于一颗线段树来说,最基本的操作有建树(废话,不建数我们怎么查询?)、查询(各种信息)、修改(最基本的涉及一种修改,比如说把一个 数列上加上/减去一个值或者说把一个 数列上
那么我们来看一看具体的代码实现(我分了指针和数组版本,个人推荐使用指针版本,因为竞赛的内容我们需要追求更快,指针+读入优化可以让你畅享飞一般的感觉)

首先是建立一颗树(我们假设给定的数组是a[100000])(指针版本)

#include<bits/stdc++.h>//考试的时候不要用这个库,到处都是bug
using namespace std;
const int MAXN=10000+5;//多定义一丢丢防止访问无效内存,但是不要定义得太大,尤其是在用数组的时候,会栈溢出的 
struct Node
{
	int sum;  //维护区间和 ; 
	int vmax; //维护区间最大值,这里注意一下,变量的名字不要定义成max或者min ,和函数重名的话在评测过程里可能会是0分 
	int vmin;  // 维护区间最小值 
	Node *ls;  //用指针的话就不需要考虑节点编号了,因为可以直接用左右儿子来进行 
	Node *rs;  //同上 
	void upate() 
	{
		sum=ls->sum+rs->sum;// 这个步骤是用于维护区间的信息的,将下面的区间的信息向上传递
		vmax=max(ls->vmax,rs->vmax);
		vmin=min(ls->vmin,rs->vmin); 
	}
}pool[N<<1],*root,*tail=pool;//由于节点总数是2*N-1,所以说我们的pool的大小是2*N; 

Node *build(int l,int r)// 使用数组的时候build函数是void,因为数组的根节点默认是1,但是在使用指针时我们必须要知道根节点 
{
	Node *nd=++tail;//定义当前节点; 
	if(l==r)// 此时已经到达单点,没有办法再继续二分了 
	{
		nd->sum=nd->vmax=nd->vmin=a[l];// 更新最值 
	} 
	else 
	{
		int mid=(l+r)>>1;//位运算操作符,效果等同于(l+r)/2
		nd->ls=build(l,mid)//更新左右儿子 
		nd->rs=build(mid+1,r);
		nd->update();
	} 
	return nd;//把当前位置的值返回上去 
}
好的,像这样操作的话建立一颗树的操作就差不多了

当然,我们也有数组的操作

void update(int o)
{
	sum[o]=sum[o<<1]+sum[o<<1|1];
	vmax[o]=max(vmax[o<<1],vmax[o<<1|1]);
	vmin[o]=min(vmin[o<<1],vmin[o<<1|1]);
}
void build(int o,int l,int r)//在实际操作的时候,一般是build(1,1,n),根节点默认是1 
{
	if(l==r)
	{
		sum[o]=a[l];
		vmax[o]=a[l];
		vmin[o]=a[l];//当然在这里也可以定义一个结构体,储存三个信息 
	}
	int mid=(l+r)>>1;
	build(o<<1,l,mid);
	build(o<<1|1,mid+1,r);
	update(o); 
}

对于查询操作,我们完全可以只讲区间查询,因为单点查询可以转换成区间查询(左右区间相同便是查询它自己),
而对于区间查询的过程,以上已经讲得差不多了,我们仍然是看代码实现(注意,截至目前为止,还没有涉及到或者是单点的修改操作)
long long query(Node *nd,int l,int r,const int L,const int R)//查询[l,r]之间的和,最大值、最小值类推即可
{
    if(L<=l&&R>=r)
    {
        return nd->sum;
    }
    int mid=(l+r)>>1;
    long long ans=0;
    if(L<=mid)
    {
        ans+=query(nd->ls,l,mid,L,R);
    }
    if(mid<R)
    {
        ans+=query(nd->rs,mid+1,r,L,R);
    }
    return ans;
}

以上,便是对于一颗线段树来说的最基本的操作,建树和区间/单点查询。
然后,我们来涉及一些稍微复杂一点的操作,首先还是从它们中最简单的开始吧,也就是最开始那道题里的单点修改
void modify(Node *nd,int l,int r,const int pos,const int delta)
{
	if(l==r)
	{
		nd->sum=delta;//如果是单点加上delta,则此操作改为nd->sum+=delta' 
		nd->vmin=delta;//如果是加上delta,则改为nd->vmin+=delta; 
		nd->vmax=delta;//加上delta改为 nd->vmax+=delta; 
		return;
	}
	int mid=(l+r)>>1;
	if(pos<=mid)
	{
		modify(nd->ls,l,mid,pos,delta);
	}
	else 
	{
		modify(nd->rs,mid+1,pos,delta);
	}
	nd->update();
}

可以看到,单点修改就是二分查找,找到这个点就改,没有找就对比过后继续递归,这个操作相当的简单,但是当操作变成区间修改的时候,情况就有些不太相同了
我们如果说要修改某一个区间内所有数的值,就会陷入一个窘境

在这张图中,如果我们要修改1~5这个区间的值,那么我们会发现,如果依然按照刚才的方法操作的话,在modify1~5这个区间的时候,完成了对于区间的修改,然后便返回。 但是对于1~5所有儿子的信息,我们是没有进行修改的。所以在这个基础之上,我们需要一种操作来实现对于某个区间所有儿子的信息的修改,我们一般把这个操作叫做pushdown
我们先来看一看代码(仍然是数组指针两种方法):
struct Node
{
    long long sum;
    Node *ls;
    Node *rs;
    bool flag=0;
    long long tag;
    void update()
    {
        sum=ls->sum+rs->sum;
    }
    void pushdown(int l,int r)
    {
        if(flag)
        {
            int mid=(l+r)>>1;
            ls->tag+=tag;
            rs->tag+=tag;
            ls->sum+=tag*(mid-l+1);
            rs->sum+=tag*(r-mid);
            ls->flag=rs->flag=1;
            flag=0;
            tag=0;
        }
    }
}pool[N*2],*tail=pool,*root;
void pushdown(int o,int l,int r){
	if(flag[o])
	{
		int mid=(l+r)>>1;
		sum[o*2]+=tag[o]*(mid-l+1);
		sum[o*2+1]+=tag[o]*(r-mid);
		tag[o*2]+=tag[o];
		tag[o*2+1]+=tag[o];
		flag[o*2]=flag[o*2+1]=1;
		tag[o]=flag[o]=0;
	}
}

我们在这里举的pushdown的例子是维护区间和的,至于区间的其它信息都可以类推。我们结合代码来看一看一些重要信息。
首先可以看到flag跟tag两个变量,flag是一个下放标记,当一个点的flag等于1的时候,表示一种“ 我想要修改这个节点,但是目前为止还没有进行这个操作的状态”,而tag表示对应节点的需要进行修改操作的值,我们每一次修改一个区间,就将这个区间对应节点的儿子节点打上下放标记(其实就是懒得改,laze思想),并且在 查询操作的时候对这些内容进行修改。需要注意的一点是关于sum的更新,我们修改一个区间(比如说把这个区间加上一个数),那么这个区间的最大值最小值会加上这个数,而这个区间的和则是加上(这个数)*(区间长度)(其实仔细想想还是很好理解的)。
同时,因为标记和laze思想的存在,我们便需要对于modify和query两个方面的操作做一些修改(其实就是加上pushdown的操作)

代码实现(指针):

void modify(Node *nd,int l,int r,const int L,const int R,const long long delta)
{
    if(L<=l&&R>=r)
    {
        nd->sum+=delta*(r-l+1);
        nd->tag+=delta;
        nd->flag=1;
        return;
    }
    nd->pushdown(l,r);
    int mid=(l+r)>>1;
    if(L<=mid)
    {
        modify(nd->ls,l,mid,L,R,delta);
    }
    if(R>mid)
    {
        modify(nd->rs,mid+1,r,L,R,delta);
    }
    nd->update();
}
long long query(Node *nd,int l,int r,const int L,const int R)
{
    if(L<=l&&R>=r)
    {
        return nd->sum;
    }
    int mid=(l+r)>>1;
    nd->pushdown(l,r);
    long long ans=0;
    if(L<=mid)
    {
        ans+=query(nd->ls,l,mid,L,R);
    }
    if(mid<R)
    {
        ans+=query(nd->rs,mid+1,r,L,R);
    }
    return ans;
}

数组版(好吧这其实是我个人的问题,因为对于初学者来说指针数组的切换一般不熟练,所以我把两个版本都写一遍):


void modify(int o,int l,int r,const int L,const int R,const long long val)
{
	if(L<=l && r<=R){
		tar[o]+=val;
		sum[o]+=val*(r-l+1);
		flag[o]=1;
		return ;
	}
	pushdown(o,l,r);
	int mid=(l+r)>>1;
	if(mid>=L) modify(o*2,l,mid,L,R,val);
	if(mid<R) modify(o*2+1,mid+1,r,L,R,val);
	sum[o]=sum[o*2]+sum[o*2+1];
}
long long query(int o,int l,int r,const int L,const int R){
	if(L<=l && r<=R)
		return sum[o];
	pushdown(o,l,r);
	int mid=(l+r)>>1;
	long long cn=0;
	if(mid>=L) cn+=query(o*2,l,mid,L,R);
	if(mid<R) cn+=query(o*2+1,mid+1,r,L,R);
	return cn;
}
那么对于线段树来说,最最最基础的操作就已经介绍完成了,欢迎要练习的同学们去做几道模板题:洛谷p3372 HDU1166
那么新人初学者的第一篇博客就在这里了,希望有人能喜欢。


猜你喜欢

转载自blog.csdn.net/Amuseir/article/details/79301619
今日推荐