【学习笔记】线段树(模板)常见错误汇总

  1. 看清题目
  2. 递归建树时,注意递归式为
build(p*2,l,mid);
build(p*2+1,mid+1,r);

而不是

build(p*2,l,r);
build(p*2+1,l,r);

当然正常人都不会写错

  3. 区间修改时,注意延迟标记的更新和下放。具体来说:
if(l<=l(p)&&r>=r(p))
{
    sum(p)+=(long long)d*(r(p)-l(p)+1);
    add(p)+=d; //更新延迟标记
    return ;
}
spread(p); //下放标记
  4. 注意区间修改递归时,传给子结点的始终是原目标区间 [l,r](不做切分),mid 只用于判断应向哪个子树递归,即:
//区间修改
if(l<=mid)change(p*2,l,r,d);
if(r>mid)change(p*2+1,l,r,d);
  5. 查询区间和时,函数 ask 的返回值类型最好为 long long;递归边界条件为 if(l<=l(p)&&r>=r(p))return sum(p);,即当前结点的区间被目标区间完全包含时,直接返回该结点的值。

  6. 同时涉及区间加和区间乘的题目,注意维护两个延迟标记(乘法标记与加法标记,下放时先乘后加)并及时取模。模板:

// Push node p's two lazy tags down to both children (multiply by mul(p) first,
// then add add(p)); every value is kept modulo mod.
void spread(ll p)
{
    // child sum = child sum * mul(p) + add(p) * (child length)
    sum(p*2)=(ll)(sum(p*2)*mul(p)+(add(p)*(r(p*2)-l(p*2)+1))%mod)%mod;
    sum(p*2+1)=(ll)(sum(p*2+1)*mul(p)+(add(p)*(r(p*2+1)-l(p*2+1)+1))%mod)%mod;
    // multiplicative tags compose by multiplication
    mul(p*2)=(ll)(mul(p*2)*mul(p))%mod;
    mul(p*2+1)=(ll)(mul(p*2+1)*mul(p))%mod;
    // the child's pending addition is scaled by the parent's multiplier, then the parent's addition is appended
    add(p*2)=(ll)(add(p*2)*mul(p)+add(p))%mod;
    add(p*2+1)=(ll)(add(p*2+1)*mul(p)+add(p))%mod;
    add(p)=0,mul(p)=1; // reset p to the identity tags (add 0, multiply 1)
}

线段树模板 1 (单点加,区间求和)

#include <iostream>
#include <cstdio>
#include <cmath>

using namespace std;

// One segment-tree node: the closed index range [l, r] it covers and the sum
// of a[l..r]. The former `dat` member was never used and has been removed;
// `sum` is long long so that node sums over many int values cannot overflow.
struct SegmentTree
{
    int l, r;
    long long sum;
};

// Node storage in heap layout: root at index 1, children of p at 2p and 2p+1.
// NOTE(review): a tree over n leaves needs up to 4n nodes, so 5000001 slots
// suffice for n up to about 1250000 -- confirm against the problem's limits.
SegmentTree t[5000001];

int n, m;          // sequence length and number of operations
int a[5000001];    // input values, 1-indexed

// Reads the next integer from stdin. Every non-digit before the number is
// skipped; any '-' seen while skipping makes the result negative.
// NOTE(review): at end of input with no digit left this never returns.
inline int read(){
    int sign = 1;
    char ch = getchar();
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sign = -1;
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + (ch - '0');
    return value * sign;
}

// Builds node p over a[x..y] and, recursively, its whole subtree.
// A leaf stores a[x]; an internal node stores the sum of its two children.
void build(int p,int x,int y)
{
    t[p].l = x;
    t[p].r = y;
    if (x != y) {
        const int lc = 2 * p, rc = 2 * p + 1;
        const int mid = (x + y) / 2;
        build(lc, x, mid);
        build(rc, mid + 1, y);
        t[p].sum = t[lc].sum + t[rc].sum; // combine bottom-up
        return;
    }
    t[p].sum = a[x];
}

// Point update: adds v to a[x]. Descends from node p to the leaf for x, adds
// v there, then walks back up (parent of c is c/2) recomputing every sum on
// the path from its two children.
void change(int p,int x,int v)
{
    int cur = p;
    while (t[cur].l != t[cur].r) {
        const int mid = (t[cur].l + t[cur].r) / 2;
        cur = (x <= mid) ? cur * 2 : cur * 2 + 1; // left child if x is not past mid
    }
    t[cur].sum += v;
    while (cur != p) {
        cur /= 2;
        t[cur].sum = t[cur * 2].sum + t[cur * 2 + 1].sum;
    }
}

// Returns the sum of a[x..y].
// Precondition: [x, y] lies inside node p's range [t[p].l, t[p].r]. Each
// recursive call narrows [x, y] so that it stays inside the child it is
// passed to. The result is long long so that adding the two halves of a
// split query cannot overflow int.
long long find(int p,int x,int y)
{
    if(x==t[p].l&&y==t[p].r)return t[p].sum; // query coincides with this node
    int mid=(t[p].l+t[p].r)/2;
    if(x>mid)return find(p*2+1,x,y); // left endpoint past mid: [x,y] lies entirely in the right child
    else if(mid>=y)return find(p*2,x,y); // likewise, [x,y] lies entirely in the left child
    else return find(p*2,x,mid)+find(p*2+1,mid+1,y); // otherwise split at mid and recurse into both children
}
/*
Equivalent alternative, which passes the query range down unchanged:
long long find(int p,int x,int y)
{
    if(x<=t[p].l&&y>=t[p].r)return t[p].sum;
    int mid=(t[p].l+t[p].r)/2;
    long long s=0;
    if(x<=mid)s+=find(p*2,x,y);
    if(y>mid)s+=find(p*2+1,x,y);
    return s;
}
*/
// Reads n, m and the sequence, then processes m operations:
//   1 x y : a[x] += y
//   2 x y : print the sum of a[x..y]
// Answers end with '\n' rather than std::endl: endl flushes the stream on
// every query, which is needlessly slow for many queries. The buffer is
// flushed once at program exit.
int main()
{
    int w,x,y;
    n=read();m=read();
    for(int i=1;i<=n;i++)a[i]=read();
    build(1,1,n);
    while(m--)
    {
        w=read(),x=read(),y=read();
        if(w==1)change(1,x,y);
        else if(w==2)cout<<find(1,x,y)<<'\n';
    }
    return 0;
}

线段树模板 2 (区间加,区间求和)

#include <iostream>
#include <cstdio>

using namespace std;

int n, m;                 // sequence length and number of operations
long long a[1000001];     // input values, 1-indexed

// Segment-tree node: covered range [l, r], range sum, and the pending lazy
// addition that has not yet been pushed to the children.
struct SegmentTree
{
    int l, r;
    long long sum, add;
};

// Shorthand accessors used by every function below.
#define l(x) t[x].l
#define r(x) t[x].r
#define sum(x) t[x].sum
#define add(x) t[x].add

SegmentTree t[1000001*4];

// Reads the next integer from stdin. Every non-digit before the number is
// skipped; any '-' seen while skipping makes the result negative.
// NOTE(review): at end of input with no digit left this never returns.
inline int read()
{
    int sign=1;
    char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())
        if(ch=='-')sign=-1;
    int value=0;
    for(;ch>='0'&&ch<='9';ch=getchar())
        value=value*10+(ch-'0');
    return value*sign;
}

// Recursively builds node p over a[l..r]. A leaf stores a[l]; an internal node
// stores the sum of its two children.
void build(int p,int l,int r)
{
    l(p)=l;
    r(p)=r;
    if(l!=r)
    {
        int mid=(l+r)/2, lc=p*2, rc=p*2+1;
        build(lc,l,mid);
        build(rc,mid+1,r);
        sum(p)=sum(lc)+sum(rc);
    }
    else sum(p)=a[l];
}

// Pushes node p's pending addition down to its two children. Each child's
// sum grows by tag * (child length), the tag is added to the child's own
// tag, and p's tag is cleared. Nothing happens when there is no pending tag.
void spread(int p)
{
    if(!add(p))return;
    const long long tag=add(p);
    for(int c=p*2;c<=p*2+1;c++)
    {
        sum(c)+=tag*(r(c)-l(c)+1);
        add(c)+=tag;
    }
    add(p)=0;
}

// Adds d to every a[i] with i in [l, r]. The target range [l, r] is passed
// down unchanged at every level; mid only decides which children intersect it.
void change(int p,int l,int r,int d)
{
    if(l>l(p)||r<r(p))
    {
        // partial overlap: push the pending tag down, recurse, then recombine
        spread(p);
        int mid=(l(p)+r(p))/2;
        if(l<=mid)change(p*2,l,r,d);
        if(mid<r)change(p*2+1,l,r,d);
        sum(p)=sum(p*2)+sum(p*2+1);
        return;
    }
    // node lies completely inside [l, r]: update its sum and record the lazy tag
    sum(p)+=(long long)d*(r(p)-l(p)+1);
    add(p)+=d;
}

// Returns the sum of a[i] over i in [l, r] within node p's subtree. Pending
// tags are pushed down before descending so that child sums are current.
long long ask(int p,int l,int r)
{
    if(l<=l(p)&&r>=r(p))return sum(p); // node lies completely inside the query
    spread(p);
    const int mid=(l(p)+r(p))/2;
    const long long lhs=(l<=mid)?ask(p*2,l,r):0;
    const long long rhs=(r>mid)?ask(p*2+1,l,r):0;
    return lhs+rhs;
}

// Reads n, m and the sequence, then processes m operations:
//   1 x y k : add k to a[x..y]
//   2 x y   : print the sum of a[x..y]
// Answers end with '\n' instead of std::endl, which would flush the stream on
// every query. Output is flushed once at program exit.
int main()
{
    int w,x,y,k;
    n=read(),m=read();
    for(int i=1;i<=n;i++)a[i]=read();
    build(1,1,n);
    while(m--)
    {
        w=read(),x=read(),y=read();
        if(w==1)k=read(),change(1,x,y,k);
        if(w==2)cout<<ask(1,x,y)<<'\n';
    }
    return 0;
}

线段树模板 3 (区间加,区间乘,区间求和)

#include <iostream>
#include <cstdio>

#define ll long long 

using namespace std;

int n, m, mod;    // sequence length, number of operations, modulus
ll a[1000001];    // input values, 1-indexed

// Segment-tree node: covered range [l, r], range sum modulo mod, and the two
// pending lazy tags not yet pushed to the children. When pushed down, the
// child's values are multiplied by mul first, then increased by add.
struct SegmentTree
{
    int l, r;
    ll sum, add, mul;
};

// Shorthand accessors used by every function below.
#define l(x) t[x].l
#define r(x) t[x].r
#define sum(x) t[x].sum
#define add(x) t[x].add
#define mul(x) t[x].mul

SegmentTree t[1000001*4];

// Reads the next integer from stdin. Every non-digit before the number is
// skipped; any '-' seen while skipping makes the result negative.
// The accumulator is long long: it used to be int, so values beyond the int
// range overflowed even though the function returns long long (ll).
// If end of input is reached before any digit, 0 is returned instead of
// spinning forever on EOF.
inline long long read()
{
    long long s=0;
    int w=1;
    int ch=getchar();   // int, so EOF can be told apart from real characters
    while(ch<'0'||ch>'9')
    {
        if(ch==EOF)return 0;
        if(ch=='-')w=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')s=s*10+(ch-'0'),ch=getchar();
    return s*w;
}

// Recursively builds node p over a[l..r] with identity multiplicative tag 1.
// Leaves store a[l] % mod; internal nodes store (left sum + right sum) % mod.
void build(ll p,ll l,ll r)
{
    l(p)=l, r(p)=r, mul(p)=1;
    if(l!=r)
    {
        ll mid=(l+r)/2;
        build(2*p,l,mid);
        build(2*p+1,mid+1,r);
        sum(p)=(sum(2*p)+sum(2*p+1))%mod;
    }
    else sum(p)=a[l]%mod;
}

// Pushes node p's lazy tags (multiply by mul(p), then add add(p)) down to both
// children, modulo mod, and resets p to the identity tags (add 0, multiply 1).
void spread(ll p)
{
    const ll k=mul(p), b=add(p);
    for(ll c=p*2;c<=p*2+1;++c)
    {
        const ll len=r(c)-l(c)+1;
        sum(c)=(sum(c)*k+b*len%mod)%mod;   // new sum = old sum * k + b * len
        mul(c)=mul(c)*k%mod;               // multiplicative tags compose
        add(c)=(add(c)*k+b)%mod;           // old pending add is scaled, then b appended
    }
    add(p)=0, mul(p)=1;
}

// Adds d (modulo mod) to every a[i] with i in [l, r].
// d is first normalised into [0, mod). A negative increment would otherwise
// leave negative values in the tree, because C++ % keeps the dividend's sign.
// A large increment could also overflow d * (segment length).
void change1(ll p,ll l,ll r,ll d)
{
    d%=mod;
    if(d<0)d+=mod;
    if(l<=l(p)&&r>=r(p))
    {
        // node lies completely inside [l, r]: update it and record the lazy add
        sum(p)=(sum(p)+d*(r(p)-l(p)+1)%mod)%mod;
        add(p)=(add(p)+d)%mod;
        return ;
    }
    spread(p);
    ll mid=(l(p)+r(p))/2;
    if(l<=mid)change1(p*2,l,r,d);
    if(r>mid)change1(p*2+1,l,r,d);
    sum(p)=(sum(p*2)+sum(p*2+1))%mod;
}

// Multiplies every a[i] with i in [l, r] by d (modulo mod).
// d is first normalised into [0, mod). A negative multiplier would otherwise
// leave negative values in the tree, because C++ % keeps the dividend's sign.
// A large multiplier could also overflow the products with sum/add/mul.
void change2(ll p,ll l,ll r,ll d)
{
    d%=mod;
    if(d<0)d+=mod;
    if(l<=l(p)&&r>=r(p))
    {
        // node lies completely inside [l, r]: scale its sum and both lazy tags
        sum(p)=sum(p)*d%mod;
        add(p)=add(p)*d%mod;   // a pending addition is scaled as well
        mul(p)=mul(p)*d%mod;
        return ;
    }
    spread(p);
    ll mid=(l(p)+r(p))/2;
    if(l<=mid)change2(p*2,l,r,d);
    if(r>mid)change2(p*2+1,l,r,d);
    sum(p)=(sum(p*2)+sum(p*2+1))%mod;
}

// Returns the sum of a[i] over i in [l, r] within node p's subtree, modulo
// mod. Pending tags are pushed down before descending.
ll ask(ll p,ll l,ll r)
{
    if(l<=l(p)&&r>=r(p))return sum(p); // node lies completely inside the query
    spread(p);
    const ll mid=(l(p)+r(p))/2;
    ll lhs=0, rhs=0;
    if(l<=mid)lhs=ask(p*2,l,r);
    if(r>mid)rhs=ask(p*2+1,l,r);
    return (lhs%mod+rhs)%mod;
}

// Reads n, m, the modulus and the sequence, then processes m operations:
//   1 x y k : multiply a[x..y] by k
//   2 x y k : add k to a[x..y]
//   3 x y   : print the sum of a[x..y] modulo mod
// Answers end with '\n' instead of std::endl, which would flush the stream on
// every query. Output is flushed once at program exit.
int main()
{
    ll w,x,y,k;
    n=read(),m=read(),mod=read();
    for(int i=1;i<=n;i++)a[i]=read();
    build(1,1,n);
    while(m--)
    {
        w=read(),x=read(),y=read();
        if(w==1)k=read(),change2(1,x,y,k);
        if(w==2)k=read(),change1(1,x,y,k);
        if(w==3)cout<<ask(1,x,y)<<'\n';
    }
    return 0;
}

线段树模板4 区间加,单点查询

#include <iostream>
#include <stdio.h>
#include <math.h>
#define ll long long

using namespace std;

int n, m;          // sequence length and number of operations
int a[500010];     // input values, 1-indexed

// Segment-tree node: covered range [l, r], range sum v (read and written
// through the sum() accessor), and the pending lazy addition.
struct T
{
    int l, r;
    ll v, add;
};

// Shorthand accessors used by every function below.
#define l(f) t[f].l
#define r(f) t[f].r
#define sum(f) t[f].v
#define add(f) t[f].add

T t[500010*4];

// Reads the next integer from stdin. Every non-digit before the number is
// skipped; any '-' seen while skipping makes the result negative.
// NOTE(review): at end of input with no digit left this never returns.
inline int read()
{
    int sign=1;
    char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())
        if(ch=='-')sign=-1;
    int value=0;
    for(;ch>='0'&&ch<='9';ch=getchar())
        value=value*10+(ch-'0');
    return value*sign;
}

// Recursively builds node p over a[l..r]. A leaf stores a[l]; an internal node
// stores the sum of its two children (2p and 2p+1).
void build(int p,int l,int r)
{
    l(p)=l, r(p)=r;
    if(l!=r)
    {
        int mid=(l+r)>>1;
        build(p<<1,l,mid);
        build(p<<1|1,mid+1,r);
        sum(p)=sum(p<<1)+sum(p<<1|1);
    }
    else sum(p)=a[l];
}

void spread(int p)
{
    if(add(p))
    {
        sum(p*2)+=(ll)add(p)*(r(p*2)-l(p*2)+1);
        sum(p*2+1)=sum(p*2+1)+(ll)add(p)*(r(p*2+1)-l(p*2+1)+1);
        add(p*2)+=add(p);
        add(p*2+1)+=add(p);
        add(p)=0;
    }
}

// Adds v to every a[i] with i in [l, r]. The target range is passed down
// unchanged; mid only decides which children intersect it.
void change(int p,int l,int r,int v)
{
    const bool inside = l<=l(p) && r>=r(p);
    if(!inside)
    {
        // partial overlap: push the pending tag down, recurse, then recombine
        spread(p);
        const int mid=(l(p)+r(p))>>1;
        if(l<=mid)change(p<<1,l,r,v);
        if(r>mid)change(p<<1|1,l,r,v);
        sum(p)=sum(p<<1)+sum(p<<1|1);
        return;
    }
    // node lies completely inside [l, r]: update its sum and record the lazy tag
    sum(p)+=(ll)v*(r(p)-l(p)+1);
    add(p)+=v;
}

// Point query: returns the current value of a[x]. Pending additions on the
// path are pushed down so the leaf holds the up-to-date value.
// The descent stops at any leaf. The original required l==x&&r==x, which
// never held for an out-of-range x and so called spread() on a leaf.
// Every path now ends in an explicit return; previously control could reach
// the end of this non-void function as far as the compiler could tell
// (-Wreturn-type).
long long ask(int p,int x)
{
    if(l(p)==r(p))return sum(p); // leaf; for 1 <= x <= n this is exactly a[x]
    spread(p);
    int mid=(l(p)+r(p))>>1;
    if(x<=mid)return ask(p*2,x);
    return ask(p*2+1,x);
}

// Reads n, m and the sequence, then processes m operations:
//   2 x       : print a[x]
//   other x y k : add k to a[x..y]
int main()
{
    n=read();
    m=read();
    for(int i=1;i<=n;++i)a[i]=read();
    build(1,1,n);
    while(m--)
    {
        int w=read(),x=read();
        if(w!=2)
        {
            int y=read(),k=read();
            change(1,x,y,k);
            continue;
        }
        printf("%lld\n",ask(1,x));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Nicest1919/p/12308594.html
今日推荐