线段树是算法竞赛中常用的用来维护 区间信息 的数据结构。
线段树可以在 O(logN)
的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。
线段树相对于树状数组代码较长,但是线段树较全面,功能性更强。
线段树的基本分为两类:
- 单点修改,不需要懒标记。
pushup: 由子节点向上更新父节点
bulild: 一段区间初始化成线段树
modify:修改
query:查询
- 区间修改,需要懒标记。
pushup: 由子节点向上更新父节点
pushdown: 由父节点向下更新子节点
bulild: 一段区间初始化成线段树
modify:修改
query:查询
线段树代码实现的两种方法:
- 结构体来实现,结构体内保存了各个信息。
- 数组来实现,则需要在写线段树函数的时候需要额外的维护信息。
线段树需要开四倍的原因:
我们不难看出上图的倒数第二行最多为n
,那么上面就是n-1
那么最后一行最多为2*n
故总共最多为4*n-1
下面我们来具体的看一下线段树代码实现的两种方法:
用结构体来实现线段树:
#include<bits/stdc++.h>
using namespace std;
const int N=1e5*5+10;
struct node{
int l,r,sum;}tr[N*4];// sum表示 [l,r] 内的和
int n,m,a[N];
void pushup(int u)
{
tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
}
void build(int u,int l,int r)
{
tr[u]={
l,r};
if(l==r)
{
tr[u].sum=a[l];
return;//子节点
}
int mid=l+r>>1;
build(u*2,l,mid);
build(u*2+1,mid+1,r);
pushup(u);
}
void modify(int u,int x,int v)
{
if(tr[u].l==x&&tr[u].r==x)//子节点
{
tr[u].sum+=v;
return;
}
else
{
int mid=(tr[u].l+tr[u].r)/2;
if(x<=mid) modify(u*2,x,v);//在左边
else modify(u*2+1,x,v);//右边
pushup(u);//用子节点的信息更新父节点
}
}
int query(int u,int l,int r)
{
if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sum;//完全包含
else
{
int mid=(tr[u].l+tr[u].r)/2;
int sum=0;
if(l<=mid) sum+=query(u*2,l,r);//与左边有交集
if(r>mid) sum+=query(u*2+1,l,r);//与右边有交集
return sum;
}
}
int main(void)
{
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
for(int i=0;i<m;i++)
{
int op; cin>>op;
if(op==1)
{
int x,k; cin>>x>>k;
modify(1,x,k);
}
else
{
int l,r; cin>>l>>r;
cout<<query(1,l,r)<<endl;
}
}
return 0;
}
数组实现线段树:
#include<bits/stdc++.h>
using namespace std;
const int N=1e5*5+10;
int a[N],f[N*4],n,m;//f[i]表示编号为i的区间的和
void build(int u,int l,int r)
{
if(l==r)
{
f[u]=a[l];
return;
}
int mid=l+r>>1;
build(u*2,l,mid);
build(u*2+1,mid+1,r);
f[u]=f[u*2]+f[u*2+1];
}
void add(int u,int l,int r,int x,int v)
{
f[u]+=v;
if(l==r) return;
int mid=l+r>>1;
if(x<=mid) add(u*2,l,mid,x,v);
else add(u*2+1,mid+1,r,x,v);
}
int query(int u,int l,int r,int l1,int r1)
{
if(l==l1&&r==r1) return f[u];
int mid=l+r>>1;
if(mid>=r1) return query(u*2,l,mid,l1,r1);
else if(l1>=mid+1) return query(u*2+1,mid+1,r,l1,r1);
else
{
int sum=0;
sum+=query(u*2,l,mid,l1,mid);
sum+=query(u*2+1,mid+1,r,mid+1,r1);
return sum;
}
}
int main(void)
{
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
for(int i=0;i<m;i++)
{
int op; cin>>op;
if(op==1)
{
int x,k; cin>>x>>k;
add(1,1,n,x,k);
}
else
{
int l,r; cin>>l>>r;
cout<<query(1,1,n,l,r)<<endl;
}
}
return 0;
}