fzu 2105 Digits Count(成段更新)

题意:给你N个数,有四种操作。(1)"AND opn L R",表示对区间[L,R]内的数全部与opn进行且(&)操作。(2)"OR opn L R",表示对区间[L,R]内的数全部与opn进行或(|)操作。(3)"XOR opn L R",表示对区间[L,R]内的数全部与opn进行异或(^)操作。(4)"SUM opn L R",求出区间[L,R]的数的和。

注意到给出的N个数的和opn的范围只有0-15(用二进制看,就是只有4位),把每个数的每一位单独拉出来,建一次线段树,得到每个询问内每一位有多少个1,就得到最后的答案。

说白了,就是全部赋值为1或者0,并且把区间里面0和1的数目对换,即翻转。我们定义前者为操作1(赋值为0或1可以合并起来处理不是?),后者为操作2。那么如果先进行操作2,再进行操作1,操作2是失效的,因为无论先怎么交换,再把它们赋值成0或1,结果都是一样的。但是,如果先进行操作1,再无进行操作2,结果是不一样的,因为先赋值,再翻转的话,就会把原来全部赋值为1的变成全部为0,反之同理。所以我们必须先传递操作1,再传递操作2,同时在传递操作1时,如果有操作2的标记,那么我们要把操作2的标记去除(因为它失效了)。

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define LL(x) (x<<1)
#define RR(x) (x<<1|1)
#define MID(a,b) (a+((b-a)>>1))
const int N=1e6+5;
int a[N],n,m;
struct Segtree
{
    int cnt0[N*4],cnt1[N*4],flag[N*4],flag_x[N*4];
    void fun(int ind,int iflag,int iflag_x,int lft,int rht)
    {
        if(iflag!=-1)
        {
            flag[ind]=iflag;
            if(iflag==0) cnt0[ind]=rht-lft+1,cnt1[ind]=0;
            else if(iflag==1) cnt1[ind]=rht-lft+1,cnt0[ind]=0;
            flag_x[ind]=0;
        }

        if(iflag_x)
        {
            flag_x[ind]^=1;
            swap(cnt0[ind],cnt1[ind]);
        }
    }

    void PushDown(int ind,int lft,int rht)
    {
        if(flag[ind]!=-1||flag_x[ind]==1)
        {
            int mid=MID(lft,rht);
            fun(LL(ind),flag[ind],flag_x[ind],lft,mid);
            fun(RR(ind),flag[ind],flag_x[ind],mid+1,rht);
            flag[ind]=-1;   flag_x[ind]=0;
        }
    }
    void PushUp(int ind)
    {
        cnt0[ind]=cnt0[LL(ind)]+cnt0[RR(ind)];
        cnt1[ind]=cnt1[LL(ind)]+cnt1[RR(ind)];
    }
    void build(int lft,int rht,int ind,int bit)
    {
        cnt0[ind]=0;    cnt1[ind]=0;    flag[ind]=-1;   flag_x[ind]=0;
        if(lft==rht)
        {
            if((a[lft]&(1<<bit))==0) cnt0[ind]=1;
            else cnt1[ind]=1;
        }
        else
        {
            int mid=MID(lft,rht);
            build(lft,mid,LL(ind),bit);
            build(mid+1,rht,RR(ind),bit);
            PushUp(ind);
        }
    }
    void updata(int st,int ed,int iflag,int iflag_x,int lft,int rht,int ind)
    {
        if(st<=lft&&rht<=ed) fun(ind,iflag,iflag_x,lft,rht);
        else
        {
            PushDown(ind,lft,rht);
            int mid=MID(lft,rht);
            if(st<=mid) updata(st,ed,iflag,iflag_x,lft,mid,LL(ind));
            if(ed> mid) updata(st,ed,iflag,iflag_x,mid+1,rht,RR(ind));
            PushUp(ind);
        }
    }
    int query(int st,int ed,int lft,int rht,int ind)
    {
        if(st<=lft&&rht<=ed) return cnt1[ind];
        else
        {
            PushDown(ind,lft,rht);
            int mid=MID(lft,rht),num=0;
            if(st<=mid) num+=query(st,ed,lft,mid,LL(ind));
            if(ed> mid) num+=query(st,ed,mid+1,rht,RR(ind));
            PushUp(ind);
            return num;
        }
    }
}seg;
struct OP
{
    char cmd[10];
    int opn,L,R,cnt[5];
    void get()
    {
        memset(cnt,0,sizeof(cnt));
        scanf("%s",cmd);
        if(cmd[0]=='S') scanf("%d%d",&L,&R);
        else scanf("%d%d%d",&opn,&L,&R);
        L=max(0,L); R=min(n-1,R);
    }
    void print()
    {
        for(int i=0;i<4;i++) printf("%d ",cnt[i]);puts("");
    }
}op[100005];

void solve(int bit)
{
    seg.build(0,n-1,1,bit);

    for(int i=0;i<m;i++)
    {
        if(op[i].cmd[0]=='A')
        {
            if(op[i].opn&(1<<bit)) continue;
            else seg.updata(op[i].L,op[i].R,0,0,0,n-1,1);
        }
        else if(op[i].cmd[0]=='O')
        {
            if(op[i].opn&(1<<bit)) seg.updata(op[i].L,op[i].R,1,0,0,n-1,1);
            else continue;
        }
        else if(op[i].cmd[0]=='X')
        {
            if(op[i].opn&(1<<bit)) seg.updata(op[i].L,op[i].R,-1,1,0,n-1,1);
            else continue;
        }
        else
        {
            op[i].cnt[bit]=seg.query(op[i].L,op[i].R,0,n-1,1);
        }
    }
}
int main()
{
    //freopen("D.in","r",stdin);

    int t;
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d%d",&n,&m);
        for(int i=0;i<n;i++) scanf("%d",&a[i]);
        for(int i=0;i<m;i++) op[i].get();
        for(int i=0;i<4;i++) solve(i);

        for(int i=0;i<m;i++) if(op[i].cmd[0]=='S')
        {
            int ans=op[i].cnt[0]+op[i].cnt[1]*2+op[i].cnt[2]*2*2+op[i].cnt[3]*2*2*2;
           // op[i].print();
            printf("%d\n",ans);
        }
    }
    return 0;
}


猜你喜欢

转载自blog.csdn.net/shiqi_614/article/details/8869522