BZOJ5294 BJOI2018 二进制 线段树

传送门


因为每一位\(\mod 3\)的值为\(1,2,1,2,...\),也就相当于\(1,-1,1,-1,...\)

所以当某个区间的\(1\)的个数为偶数的时候,一定是可行的,只要把这若干个\(1\)放在一起就可以了。

而当某个区间的\(1\)的个数为奇数的时候,那么最优的方式显然是\(1\)\(-1\)两两配对,剩下\(3\)\(1\),然后留下至少\(2\)\(0\),将\(111\)拼成\(10101\)的形式。

注意到\(1\)的个数为\(1\)的时候显然不可行。

所以合法的区间需要满足:\(1\)的个数大于\(1\)、个数为偶数或者个数为奇数且\(0\)的个数\(\geq 2\)

为了方便计算我们将合法区间数量变为总区间数量减去不合法区间数量,而不合法区间需要满足:\(1\)的个数为\(1\),或者个数为奇数且\(0\)的个数\(\leq 1\)

使用线段树维护区间不合法区间数量,对于每一个节点可以维护:区间中不合法区间数量、区间\(0/1\)个数、\(1\)的个数为\(0/1\)的前缀/后缀个数、\(1\)的个数为奇数/偶数且\(0\)的个数为\(0/1\)的前缀/后缀个数、区间左端和右端的数字,就可以计算答案。注意pushup的细节,不要打错就好。

还有一个需要注意的地方:如果用上面的方法,\(10\)\(01\)两种区间会被重复计算,要在计算的时候减掉。

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;

inline int read(){
    int a = 0;
    char c = getchar();
    while(!isdigit(c)) c = getchar();
    while(isdigit(c)){
    a = a * 10 + c - 48;
    c = getchar();
    }
    return a;
}

const int MAXN = 1e5 + 7;
#define int long long
int N , M;
namespace segTree{
    struct node{
        int sum , cnt0 , cnt1 , l , r , lft1[2] , rht1[2] , lft[2][2] , rht[2][2];
    }Tree[MAXN << 2];
    
#define lch (x << 1)
#define rch (x << 1 | 1)
#define mid ((l + r) >> 1)

    node pushup(node a , node b){
    node t;

    t.l = a.l; t.r = b.r;
    
    t.sum = a.sum + b.sum - (a.r ^ b.l);
    t.sum += a.rht1[1] * b.lft1[0] + a.rht1[0] * b.lft1[1];
    t.sum += (a.rht[1][0] + a.rht[1][1]) * (b.lft[0][0] + b.lft[0][1]) - a.rht[1][1] * b.lft[0][1];
    t.sum += (a.rht[0][1] + a.rht[0][0]) * (b.lft[1][0] + b.lft[1][1]) - a.rht[0][1] * b.lft[1][1];

    t.cnt0 = a.cnt0 + b.cnt0; t.cnt1 = a.cnt1 + b.cnt1;

    t.lft1[0] = a.lft1[0] + (a.cnt1 == 0 ? b.lft1[0] : 0);
    t.lft1[1] = a.lft1[1] + (a.cnt1 == 0 ? b.lft1[1] : (a.cnt1 == 1 ? b.lft1[0] : 0));
    t.rht1[0] = b.rht1[0] + (b.cnt1 == 0 ? a.rht1[0] : 0);
    t.rht1[1] = b.rht1[1] + (b.cnt1 == 0 ? a.rht1[1] : (b.cnt1 == 1 ? a.rht1[0] : 0));

    bool f1 = a.cnt1 & 1 , f2 = b.cnt1 & 1;
    t.lft[0][0] = a.lft[0][0] + (a.cnt0 == 0 ? b.lft[f1][0] : 0);
    t.lft[0][1] = a.lft[0][1] + (a.cnt0 == 0 ? b.lft[f1][1] : (a.cnt0 == 1 ? b.lft[f1][0] : 0));
    t.lft[1][0] = a.lft[1][0] + (a.cnt0 == 0 ? b.lft[!f1][0] : 0);
    t.lft[1][1] = a.lft[1][1] + (a.cnt0 == 0 ? b.lft[!f1][1] : (a.cnt0 == 1 ? b.lft[!f1][0] : 0));
    
    t.rht[0][0] = b.rht[0][0] + (b.cnt0 == 0 ? a.rht[f2][0] : 0);
    t.rht[0][1] = b.rht[0][1] + (b.cnt0 == 0 ? a.rht[f2][1] : (b.cnt0 == 1 ? a.rht[f2][0] : 0));
    t.rht[1][0] = b.rht[1][0] + (b.cnt0 == 0 ? a.rht[!f2][0] : 0);
    t.rht[1][1] = b.rht[1][1] + (b.cnt0 == 0 ? a.rht[!f2][1] : (b.cnt0 == 1 ? a.rht[!f2][0] : 0));
    return t;
    }

    void init(int x , int l , int r){
    if(l == r){
        Tree[x].cnt0 = !(Tree[x].sum = Tree[x].cnt1 = read());
        Tree[x].lft[1][0] = Tree[x].rht[1][0] = Tree[x].lft1[1] = Tree[x].rht1[1] = Tree[x].l = Tree[x].r = Tree[x].cnt1;
        Tree[x].lft[0][1] = Tree[x].rht[0][1] = Tree[x].lft1[0] = Tree[x].rht1[0] = Tree[x].cnt0;
    }
    else{
        init(lch , l , mid);
        init(rch , mid + 1 , r);
        Tree[x] = pushup(Tree[lch] , Tree[rch]);
    }
    }

    void modify(int x , int l , int r , int tar){
    if(l == r){
        Tree[x].l ^= 1; Tree[x].r ^= 1;
        Tree[x].cnt0 ^= 1; Tree[x].sum ^= 1; Tree[x].cnt1 ^= 1;
        Tree[x].lft[1][0] ^= 1; Tree[x].rht[1][0] ^= 1; Tree[x].lft1[1] ^= 1; Tree[x].rht1[1] ^= 1;
        Tree[x].lft[0][1] ^= 1; Tree[x].rht[0][1] ^= 1; Tree[x].lft1[0] ^= 1; Tree[x].rht1[0] ^= 1;
        return;
    }
    if(mid >= tar) modify(lch , l , mid , tar);
    else modify(rch , mid + 1 , r , tar);
    Tree[x] = pushup(Tree[lch] , Tree[rch]);
    }

    node query(int x , int l , int r , int L , int R){
    if(l >= L && r <= R) return Tree[x];
    bool f = mid >= L; node t;
    if(f) t = query(lch , l , mid , L , R);
    if(mid < R) t = f ? pushup(t , query(rch , mid + 1 , r , L , R)) : query(rch , mid + 1 , r , L , R);
    return t;
    }
}
using segTree::init; using segTree::modify; using segTree::query;

signed main(){
    #ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    //freopen("out","w",stdout);
    #endif
    N = read();
    init(1 , 1 , N);
    for(int M = read() ; M ; --M)
    if(read() == 2){
        int l = read() , r = read();
        cout << (r - l + 1) * (r - l + 2) / 2 - query(1 , 1 , N , l , r).sum << endl;
    }
    else modify(1 , 1 , N , read());
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Itst/p/10498675.html