bzoj 1500: [NOI2005]维修数列(splay大模板)

没什么好说的,就是各种操作。

st数组用来回收内存,root为根,tot为点数,lmax为左端最大值,rmax为右端最大值,smax为全区

间最大值,sum为区间和(这部分操作见),val为权值,son【2】为左右子节点编号,siz为大小,

xg为是否有修改标记,fz为是否有反转标记。(注释部分可以不写,有助于理解)。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
int n,m,st[500005],tot,a[500005],root,pos,how,kk,top;
char ss[15];
struct node
{
    int lmax,rmax,smax,sum;
    int val,fa,son[2],siz;
    int xg,fz;
}tr[500005];
void update(int x)
{
    if(!x)return;
    int ls=tr[x].son[0],rs=tr[x].son[1];
    tr[x].sum=tr[ls].sum+tr[rs].sum+tr[x].val;
    tr[x].siz=tr[ls].siz+tr[rs].siz+1;
    tr[x].lmax=max(tr[ls].lmax,tr[ls].sum+max(tr[rs].lmax,0)+tr[x].val);
    tr[x].rmax=max(tr[rs].rmax,max(tr[ls].rmax,0)+tr[rs].sum+tr[x].val);
    tr[x].smax=max(max(tr[ls].smax,tr[rs].smax),max(tr[ls].rmax,0)+max(tr[rs].lmax,0)+tr[x].val);
}
void pushdown(int x)
{
    int l=tr[x].son[0],r=tr[x].son[1];
    if(tr[x].xg)
    {
        tr[x].xg=tr[x].fz=0;
        if(l)tr[l].xg=1,tr[l].val=tr[x].val,tr[l].sum=tr[l].siz*tr[l].val;
        if(r)tr[r].xg=1,tr[r].val=tr[x].val,tr[r].sum=tr[r].siz*tr[r].val;
        if(l)tr[l].lmax=tr[l].rmax=tr[l].smax=max(tr[l].sum,tr[l].val);//另外操作 
        if(r)tr[r].lmax=tr[r].rmax=tr[r].smax=max(tr[r].sum,tr[r].val);
    }
    if(tr[x].fz)
    {
        tr[x].fz=0;
        tr[l].fz^=1,tr[r].fz^=1;
        swap(tr[l].son[0],tr[l].son[1]);
        swap(tr[r].son[0],tr[r].son[1]);
        swap(tr[l].lmax,tr[l].rmax);
        swap(tr[r].lmax,tr[r].rmax);
    }
}
void rotate(int x)
{
    int y=tr[x].fa,z=tr[y].fa;
    int typ=(x==tr[y].son[1]);
    tr[y].son[typ]=tr[x].son[typ^1],tr[tr[x].son[typ^1]].fa=y;
    tr[x].son[typ^1]=y,tr[y].fa=x;
    tr[x].fa=z,tr[z].son[tr[z].son[1]==y]=x;
    update(y);update(x);
}
void splay(int x,int goal)
{
    for(int y;(y=tr[x].fa)!=goal;rotate(x))
    {
        if(tr[y].fa!=goal)rotate((x==tr[y].son[0])==(y==tr[tr[y].fa].son[0])?y:x);
    }
    if(!goal)root=x;
}
int biuld(int l,int r,int f)
{
    if(l>r)return 0;
    int mid=(l+r)/2;
    int rt=top?st[top--]:++tot;
    tr[rt].lmax=tr[rt].rmax=tr[rt].smax=tr[rt].val=tr[rt].sum=a[mid];
    tr[rt].siz=1;tr[rt].fa=f;tr[rt].son[0]=tr[rt].son[1]=0;
    tr[rt].fz=0;//这段主要与下面这句对应,也可以写在rip里 
    //if(l==r)return rt;
    tr[rt].son[0]=biuld(l,mid-1,rt);
    tr[rt].son[1]=biuld(mid+1,r,rt);
    update(rt);
    return rt;
}
int findkth(int x,int k)
{
    pushdown(x);
    int t=tr[tr[x].son[0]].siz;
    if(k<=t)return findkth(tr[x].son[0],k);
    else if(k==t+1)return x;
    else return findkth(tr[x].son[1],k-1-t);
}
void insert(int po,int tt)
{
    for(int i=1;i<=tt;i++)scanf("%d",&a[i]);
    int x=findkth(root,po+1),y=findkth(root,po+2);
    splay(x,0);
    splay(y,x);
    int z=biuld(1,tt,y);
    tr[y].son[0]=z;//tr[z].fa=y;
    update(y);update(x);
}
void rip(int x)
{
    if(!x)return;
    st[++top]=x;
    rip(tr[x].son[0]);
    rip(tr[x].son[1]);
    //tr[x].xg=tr[x].fz=tr[x].val=tr[x].sum=0;
    //tr[x].siz=tr[x].smax=tr[x].lmax=tr[x].rmax=tr[x].son[0]=tr[x].son[1]=0;
}
int deal(int x,int y)
{
    x=findkth(root,x);
    y=findkth(root,y);
    splay(x,0);
    splay(y,x);
    return tr[y].son[0];
}
void erase(int l,int r)
{
    int x=deal(l,r),y=tr[x].fa;
    rip(x);tr[y].son[0]=0;//tr[x].fa=0;
    update(y);update(root);
}
void make_same(int l,int r,int k)
{
    int x=deal(l,r),y=tr[x].fa;
    tr[x].val=k,tr[x].sum=tr[x].siz*k,tr[x].xg=1;
    tr[x].lmax=tr[x].rmax=tr[x].smax=max(tr[x].sum,tr[x].val);
    update(y);update(root);
}
void rever(int l,int r)
{
    int x=deal(l,r),y=tr[x].fa;
    if(!tr[x].xg)
    {
        tr[x].fz^=1;
        swap(tr[x].son[0],tr[x].son[1]);
        swap(tr[x].lmax,tr[x].rmax);
        update(y),update(root);
    }
}
int get_sum(int l,int r)
{
    int x=deal(l,r);
    return tr[x].sum;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=2;i<=n+1;i++)scanf("%d",&a[i]);
    /*tr[0].lmax=tr[0].rmax=tr[0].val=*/tr[0].smax=a[1]=a[n+2]=-0x3f3f3f3f;
    root=biuld(1,n+2,0);
    for(int i=1;i<=m;i++) 
    {
        scanf("%s",ss);
        if(ss[0]=='I')
        {
            scanf("%d%d",&pos,&how);
            insert(pos,how);
        }else if(ss[0]=='D')
        {
            scanf("%d%d",&pos,&how);
            erase(pos,how+pos+1);
        }else if(ss[2]=='K')
        {
            scanf("%d%d%d",&pos,&how,&kk);
            make_same(pos,pos+1+how,kk);
        }else if(ss[0]=='R')
        {
            scanf("%d%d",&pos,&how);
            rever(pos,pos+how+1); 
        }else if(ss[0]=='G')
        {
            scanf("%d%d",&pos,&how);
            printf("%d\n",get_sum(pos,pos+how+1));
        }else
        {
            printf("%d\n",tr[root].smax);
        }
    }
}

猜你喜欢

转载自blog.csdn.net/zzk_233/article/details/82712055
今日推荐