luogu P4428 [BJOI2018]二进制

luogu

先考虑怎样的二进制串才会被3整除.可以发现如果二进制位第\(0,2,4...2n\)位如果为\(1\),那么在模3意义下为1,如果二进制位第\(1,3,5...2n+1\)位如果为\(1\),那么在模3意义下为-1.所以也就是位置上是1的奇二进制位个数减位置上是1的偶二进制位个数要被3整除

在这种条件下,如果区间内1的个数为偶数显然可以从最低位开始依次放使得被3整除,如果为奇数,那么先把除了最后三个1以外的1按照偶数的情况处理,然后这三个1中间各插入一个0,也就是\(...0101011...1\).那么,不合法的情况就只剩下有区间内奇数个1同时0的个数\(<2\),或者是区间内只有一个1

合法区间比较麻烦,改为求总区间个数-不合法区间个数.为了不算重,把不合法条件改为只剩下有区间内奇数个1同时0的个数\(<2\),或者是区间内只有一个1同时\(\ge 2\).我们用线段树维护这些区间个数,对每个节点记一个\(ls_{i,j}\)表示左端点为这个线段树节点对应区间左端点的区间中,1的个数奇偶性\(0/1\),0的个数为\(0/1\)的区间个数,\(rs_{i,j}\)表示的是右端点为线段树节点右端点的相应的区间个数;\(lz_{i,j}\)表示左端点为线段树节点左端点的区间中,1的个数为\(0/1\),0的个数为\(0/1/\ge 2\)的区间个数,\(rz_{i,j}\)表示的是右端点为线段树节点右端点的相应的区间个数.以及分别记录区间\(0/1\)个数和不合法区间个数,每次合并两个节点,就计算跨越这两个节点的区间信息,可能需要一点点讨论,这里不再赘述

#include<bits/stdc++.h>
#define LL long long
#define uLL unsigned long long
#define db double

using namespace std;
const int N=1e5+10;
int rd()
{
    int x=0,w=1;char ch=0;
    while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
    return x*w;
}
struct node
{
    LL c0,c1,s;
    LL ls[2][2],rs[2][2];
    LL lz[2][3],rz[2][3];
    void clr(){memset(ls,0,sizeof(ls)),memset(rs,0,sizeof(rs)),memset(lz,0,sizeof(lz)),memset(rz,0,sizeof(rz)),c0=c1=s=0;}
    node(){}
    node(int x)
    {
        memset(ls,0,sizeof(ls)),memset(rs,0,sizeof(rs)),memset(lz,0,sizeof(lz)),memset(rz,0,sizeof(rz)),c0=c1=s=0;
        if(!x)
        {
            c0=1;
            ls[0][1]=rs[0][1]=lz[0][1]=rz[0][1]=1;
        }
        else
        {
            s=c1=1;
            ls[1][0]=rs[1][0]=lz[1][0]=rz[1][0]=1;
        }
    }
}s[N<<2],an;
node merg(node aa,node bb)
{
    an.clr();
    an.c0=aa.c0+bb.c0;
    an.c1=aa.c1+bb.c1;
    an.s=aa.s+bb.s;
    for(int i=0;i<=1;++i)
        for(int j=0;j<=1;++j)
        {
            an.ls[i][j]+=aa.ls[i][j];
            an.rs[i][j]+=bb.rs[i][j];
            if(aa.c0+j<=1) an.ls[(aa.c1&1)^i][aa.c0+j]+=bb.ls[i][j];
            if(bb.c0+j<=1) an.rs[(bb.c1&1)^i][bb.c0+j]+=aa.rs[i][j];
        }
    for(int i=0;i<=1;++i)
        for(int j=0;j<=1;++j)
            for(int k=0;k<=1;++k)
                for(int l=0;l<=1;++l)
                    if((i^k)==1&&j+l<=1) an.s+=aa.rs[i][j]*bb.ls[k][l];
    for(int i=0;i<=1;++i)
        for(int j=0;j<=2;++j)
        {
            an.lz[i][j]+=aa.lz[i][j];
            an.rz[i][j]+=bb.rz[i][j];
            if(aa.c1+i<=1) an.lz[aa.c1+i][min(aa.c0+j,2ll)]+=bb.lz[i][j];
            if(bb.c1+i<=1) an.rz[bb.c1+i][min(bb.c0+j,2ll)]+=aa.rz[i][j];
        }
    for(int i=0;i<=1;++i)
        for(int j=0;j<=2;++j)
            for(int k=0;k<=1;++k)
                for(int l=0;l<=2;++l)
                    if(i+k==1&&j+l>=2) an.s+=aa.rz[i][j]*bb.lz[k][l];
    return an;
}
int n,a[N];
void psup(int o){s[o]=merg(s[o<<1],s[o<<1|1]);}
void modif(int o,int l,int r,int lx)
{
    if(l==r){a[l]^=1;s[o]=node(a[l]);return;}
    int mid=(l+r)>>1;
    if(lx<=mid) modif(o<<1,l,mid,lx);
    else modif(o<<1|1,mid+1,r,lx);
    psup(o);
}
node quer(int o,int l,int r,int ll,int rr)
{
    if(ll<=l&&r<=rr) return s[o];
    int mid=(l+r)>>1;
    if(rr<=mid) return quer(o<<1,l,mid,ll,rr);
    if(ll>mid) return quer(o<<1|1,mid+1,r,ll,rr);
    return merg(quer(o<<1,l,mid,ll,mid),quer(o<<1|1,mid+1,r,mid+1,rr));
}
void bui(int o,int l,int r)
{
    if(l==r){s[o]=node(a[l]);return;}
    int mid=(l+r)>>1;
    bui(o<<1,l,mid),bui(o<<1|1,mid+1,r);
    psup(o);
}

int main()
{
    n=rd();
    for(int i=1;i<=n;++i) a[i]=rd();
    bui(1,1,n);
    int q=rd();
    while(q--)
    {
        int op=rd();
        if(op==1) modif(1,1,n,rd());
        else
        {
            int l=rd(),r=rd();
            printf("%lld\n",1ll*(r-l+1)*(r-l+2)/2-quer(1,1,n,l,r).s);
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/smyjr/p/11668865.html