[SCOI2010] 序列操作(线段树)

题目

描述

lxhgww最近收到了一个01序列,序列里面包含了n个数,这些数要么是0,要么是1,现在对于这个序列有五种变换操作和询问操作:
0 a b 把[a, b]区间内的所有数全变成0
1 a b 把[a, b]区间内的所有数全变成1
2 a b 把[a,b]区间内的所有数全部取反,也就是说把所有的0变成1,把所有的1变成0
3 a b 询问[a, b]区间内总共有多少个1
4 a b 询问[a, b]区间内最多有多少个连续的1
对于每一种询问操作,lxhgww都需要给出回答,聪明的程序员们,你们能帮助他吗?

输入

输入数据第一行包括2个数,n和m,分别表示序列的长度和操作数目
第二行包括n个数,表示序列的初始状态
接下来m行,每行3个数,op, a, b,(0<=op<=4,0<=a<=b

输出

对于每一个询问操作,输出一行,包括1个数,表示其对应的答案

输入样例

10 10
0 0 0 1 1 0 1 0 1 1
1 0 2
3 0 5
2 2 2
4 0 4
0 3 6
2 3 7
4 2 8
1 0 5
0 5 6
3 3 9

输出样例

5
2
6
5

说明

对于30%的数据,1<=n, m<=1000
对于100%的数据,1<=n, m<=100000


解题思路

没有区间翻转,那么线段树可以维护。

操作分析

  • 0 a b 把[a, b]区间内的所有数全变成0
  • 1 a b 把[a, b]区间内的所有数全变成1
    以上两个均为区间覆盖操作,因此我们需要一个cov标记(初值为-1表示无覆盖,值为0表示覆盖为0,值为1表示覆盖为1)
  • 2 a b 把[a,b]区间内的所有数全部取反,也就是说把所有的0变成1,把所有的1变成0
    区间反转操作,因此我们还需要一个rev标记,那么我们得好好考虑双标记问题了
  • 3 a b 询问[a, b]区间内总共有多少个1
    区间查询1的个数,其实也就是区间求和,维护一个sum即可
  • 4 a b 询问[a, b]区间内最多有多少个连续的1
    区间查询最大连续子序列问题,我们需要维护每个节点的mx1(最大连续1的个数)、lmx1(从左端点开始往右走最大连续1的个数)、rmx1(从右端点开始往左走最大连续1的个数),又因为这道题有区间反转操作,我们还要维护mx0、lmx0、rmx0。有点恶心……

线段树节点信息

根据上面的分析,线段树节点长这样:

struct segTree{
    int l, r;                       //Basic info.
    int sum, mx[2], lmx[2], rmx[2]; //Maintained info.
    int cov, rev;                   //tags
    segTree(){
        l = r = 0;
        sum = mx[0] = mx[1] = lmx[0] = lmx[1] = rmx[0] = rmx[1] = 0;
        cov = -1, rev = 0;
    }
}tr[N<<2];

双标记问题

相关链接

这道题需要两个标记,而这两个标记又互相有影响,因此我们需要给这两个标记定义优先级

  • 若cov优先级更高,则每次cov操作后要将rev清空(其实此时不清空也是正确的,只不过rev标记没有了意义),pushdown时先下放cov标记
  • 若rev优先级更高,则每次rev操作后要将cov取反(cov ^= 1),这样才能保证操作的正确性,pushdown时先下放rev标记

事实上,这两种方式都是可以的,只不过后一种稍显麻烦一点(两种代码均在文末给出)

求最大连续1的个数

这种问题也挺常见的,平衡树的题中也有(NOI2005维护数列·题解),所以单独拿出来说一下。
这种问题不仅要存一个mx(最大连续1的个数),还要存 lmx(从左端点开始往右走最大连续1的个数)和 rmx(从右端点开始往左走最大连续1的个数),向上更新时分是否跨越区间维护,应该还是比较好理解。

    inline void pushup(int id){
        tr[id].sum = tr[lid].sum + tr[rid].sum;
        for(int i = 0; i <= 1; i++){
            tr[id].mx[i] = max(max(tr[lid].mx[i], tr[rid].mx[i]), tr[lid].rmx[i] + tr[rid].lmx[i]);
            if(tr[lid].lmx[i] == size(lid)) tr[id].lmx[i] = tr[lid].lmx[i] + tr[rid].lmx[i];
            else    tr[id].lmx[i] = tr[lid].lmx[i];
            if(tr[rid].rmx[i] == size(rid)) tr[id].rmx[i] = tr[rid].rmx[i] + tr[lid].rmx[i];
            else    tr[id].rmx[i] = tr[rid].rmx[i];
        }
    }

询问处理

在解决询问时,我们需要节点的很多信息,所以为了降低常数,可以把询问函数类型定义为线段树结构体类型,方便处理


两份代码写的时间隔了几个月,所以变量名和风格稍有不同……

Code#1

cov优先

#include<cstdio>
#include<algorithm>

#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l+tr[id].r)>>1)
#define size(id) (tr[id].r-tr[id].l+1)

using namespace std;

const int N = 100005;
int n, m, a[N], opt, ql, qr;

struct segTree{
    int l, r;                       //Basic info.
    int sum, mx[2], lmx[2], rmx[2]; //Maintained info.
    int cov, rev;                   //tags
    segTree(){
        l = r = 0;
        sum = mx[0] = mx[1] = lmx[0] = lmx[1] = rmx[0] = rmx[1] = 0;
        cov = -1, rev = 0;
    }
}tr[N<<2];

struct OPT_segTree{
    inline void pushup(int id){
        tr[id].sum = tr[lid].sum + tr[rid].sum;
        for(int i = 0; i <= 1; i++){
            tr[id].mx[i] = max(max(tr[lid].mx[i], tr[rid].mx[i]), tr[lid].rmx[i] + tr[rid].lmx[i]);
            if(tr[lid].lmx[i] == size(lid)) tr[id].lmx[i] = tr[lid].lmx[i] + tr[rid].lmx[i];
            else    tr[id].lmx[i] = tr[lid].lmx[i];
            if(tr[rid].rmx[i] == size(rid)) tr[id].rmx[i] = tr[rid].rmx[i] + tr[lid].rmx[i];
            else    tr[id].rmx[i] = tr[rid].rmx[i];
        }
    }
    inline void pushdown(int id){
        if(!id || tr[id].l == tr[id].r) return;
        if(tr[id].cov != -1){
            tr[lid].sum = tr[id].cov * size(lid);
            tr[lid].lmx[tr[id].cov] = tr[lid].rmx[tr[id].cov] = tr[lid].mx[tr[id].cov] = size(lid);
            tr[lid].lmx[!tr[id].cov] = tr[lid].rmx[!tr[id].cov] = tr[lid].mx[!tr[id].cov] = 0;
            tr[lid].cov = tr[id].cov, tr[lid].rev = 0;
            tr[rid].sum = tr[id].cov * size(rid);
            tr[rid].lmx[tr[id].cov] = tr[rid].rmx[tr[id].cov] = tr[rid].mx[tr[id].cov] = size(rid);
            tr[rid].lmx[!tr[id].cov] = tr[rid].rmx[!tr[id].cov] = tr[rid].mx[!tr[id].cov] = 0;
            tr[rid].cov = tr[id].cov, tr[rid].rev = 0;
            tr[id].cov = -1;
        }
        if(tr[id].rev){
            swap(tr[lid].mx[0], tr[lid].mx[1]);
            swap(tr[lid].lmx[0], tr[lid].lmx[1]);
            swap(tr[lid].rmx[0], tr[lid].rmx[1]);
            tr[lid].sum = size(lid) - tr[lid].sum;
            tr[lid].rev ^= 1;
            swap(tr[rid].mx[0], tr[rid].mx[1]);
            swap(tr[rid].lmx[0], tr[rid].lmx[1]);
            swap(tr[rid].rmx[0], tr[rid].rmx[1]);
            tr[rid].sum = size(rid) - tr[rid].sum;
            tr[rid].rev ^= 1;
            tr[id].rev = 0;
        }
    }
    void build(int id, int l, int r){
        tr[id].l = l, tr[id].r = r;
        if(tr[id].l == tr[id].r){
            if(a[l] == 1)   tr[id].sum = tr[id].lmx[1] = tr[id].rmx[1] = tr[id].mx[1] = 1;
            if(a[l] == 0)   tr[id].sum = 0, tr[id].lmx[0] = tr[id].rmx[0] = tr[id].mx[0] = 1;
            return;
        }
        build(lid, l, mid);
        build(rid, mid+1, r);
        pushup(id);
    }
    void cover(int id, int l, int r, int val){
        pushdown(id);
        if(tr[id].l == l && tr[id].r == r){
            tr[id].sum = val * size(id);
            tr[id].lmx[val] = tr[id].rmx[val] = tr[id].mx[val] = size(id);
            tr[id].lmx[!val] = tr[id].rmx[!val] = tr[id].mx[!val] = 0;
            tr[id].cov = val, tr[id].rev = 0;
            return;
        }
        if(r <= mid)    cover(lid, l, r, val);
        else if(l > mid)    cover(rid, l, r, val);
        else    cover(lid, l, mid, val), cover(rid, mid+1, r, val);
        pushup(id);
    }
    void reverse(int id, int l, int r){
        pushdown(id);
        if(tr[id].l == l && tr[id].r == r){
            swap(tr[id].mx[0], tr[id].mx[1]);
            swap(tr[id].lmx[0], tr[id].lmx[1]);
            swap(tr[id].rmx[0], tr[id].rmx[1]);
            tr[id].sum = size(id) - tr[id].sum;
            tr[id].rev ^= 1;
            return;
        }
        if(r <= mid)    reverse(lid, l, r);
        else if(l > mid)    reverse(rid, l, r);
        else    reverse(lid, l, mid), reverse(rid, mid+1, r);
        pushup(id);
    }
    int querySum(int id, int l, int r){
        pushdown(id);
        if(tr[id].l == l && tr[id].r == r)  return tr[id].sum;
        if(r <= mid)    return querySum(lid, l, r);
        else if(l > mid)    return querySum(rid, l, r);
        else    return querySum(lid, l, mid) + querySum(rid, mid+1, r);
    }
    segTree querySub(int id, int l, int r){
        pushdown(id);
        if(tr[id].l == l && tr[id].r == r)  return tr[id];
        if(r <= mid)    return querySub(lid, l, r);
        else if(l > mid)    return querySub(rid, l, r);
        else{
            segTree L = querySub(lid, l, mid), R = querySub(rid, mid+1, r), res;
            res.l = l, res.r = r;
            res.sum = L.sum + R.sum;
            for(int i = 0; i <= 1; i++){
                res.mx[i] = max(max(L.mx[i], R.mx[i]), L.rmx[i] + R.lmx[i]);
                if(L.lmx[i] == L.r - L.l + 1)   res.lmx[i] = L.lmx[i] + R.lmx[i];
                else    res.lmx[i] = L.lmx[i];
                if(R.rmx[i] == R.r - R.l + 1)   res.rmx[i] = R.rmx[i] + L.rmx[i];
                else    res.rmx[i] = R.rmx[i];
            }
            return res;
        }
    }
}Seg;

int main(){
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
    Seg.build(1, 1, n);
    while(m--){
        scanf("%d%d%d", &opt, &ql, &qr);
        ql++, qr++;
        if(opt == 0)    Seg.cover(1, ql, qr, 0);
        else if(opt == 1)   Seg.cover(1, ql, qr, 1);
        else if(opt == 2)   Seg.reverse(1, ql, qr);
        else if(opt == 3)   printf("%d\n", Seg.querySum(1, ql, qr));
        else if(opt == 4)   printf("%d\n", Seg.querySub(1, ql, qr).mx[1]);
    }
    return 0;
}

Code#2

rev优先

#include<cstdio>
#include<algorithm>

#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l + tr[id].r) >> 1)
#define len(id) (tr[id].r - tr[id].l + 1)

using namespace std;

inline int read(){
    int x = 0;
    bool fl = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9'){
        if(ch == '-')   fl = 0;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        x = (x << 1) + (x << 3) + ch - '0';
        ch = getchar();
    }
    return fl ? x : -x;
}

const int N = 100005;
int n, q, a[N], opt, ql, qr;

struct seg_tree{
    int l, r;
    int sum, lenl[2], lenr[2], len[2];//七个变量:多少个 1,从左、从右、区间最长连续 0/1
    int rev, cov;//两个标记:取反标记 & 赋值标记 
    void init(){
        sum = lenl[0] = lenr[0] = lenl[1] = lenr[1] = 0;
        cov = -1, rev = 0;//注意初值 
    }
}tr[N<<2];

void pushup(int id){
    tr[id].sum = tr[lid].sum + tr[rid].sum;
    if(tr[lid].lenl[0] == len(lid)) tr[id].lenl[0] = tr[lid].lenl[0] + tr[rid].lenl[0];
    else    tr[id].lenl[0] = tr[lid].lenl[0];
    if(tr[lid].lenl[1] == len(lid)) tr[id].lenl[1] = tr[lid].lenl[1] + tr[rid].lenl[1];
    else    tr[id].lenl[1] = tr[lid].lenl[1];
    if(tr[rid].lenr[0] == len(rid)) tr[id].lenr[0] = tr[lid].lenr[0] + tr[rid].lenr[0];
    else    tr[id].lenr[0] = tr[rid].lenr[0];
    if(tr[rid].lenr[1] == len(rid)) tr[id].lenr[1] = tr[lid].lenr[1] + tr[rid].lenr[1];
    else    tr[id].lenr[1] = tr[rid].lenr[1];
    tr[id].len[0] = max(tr[lid].len[0], max(tr[rid].len[0], tr[lid].lenr[0] + tr[rid].lenl[0]));
    tr[id].len[1] = max(tr[lid].len[1], max(tr[rid].len[1], tr[lid].lenr[1] + tr[rid].lenl[1]));
}

void pushdown(int id){
    if(tr[id].l == tr[id].r)    return;
    if(tr[id].rev){
        swap(tr[lid].lenl[0], tr[lid].lenl[1]), swap(tr[rid].lenl[0], tr[rid].lenl[1]);
        swap(tr[lid].lenr[0], tr[lid].lenr[1]), swap(tr[rid].lenr[0], tr[rid].lenr[1]);
        swap(tr[lid].len[0], tr[lid].len[1]), swap(tr[rid].len[0], tr[rid].len[1]);
        tr[lid].sum = len(lid) - tr[lid].sum, tr[rid].sum = len(rid) - tr[rid].sum;
        tr[lid].rev ^= tr[id].rev, tr[rid].rev ^= tr[id].rev;
        if(tr[lid].cov != -1)   tr[lid].cov ^= tr[id].rev;
        if(tr[rid].cov != -1)   tr[rid].cov ^= tr[id].rev;
        tr[id].rev = 0;
    }
    if(tr[id].cov != -1){
        tr[lid].cov = tr[rid].cov = tr[id].cov;
        tr[lid].sum = len(lid) * tr[id].cov;
        tr[rid].sum = len(rid) * tr[id].cov;
        tr[lid].len[tr[id].cov^1] = tr[lid].lenl[tr[id].cov^1] = tr[lid].lenr[tr[id].cov^1] = 0;
        tr[rid].len[tr[id].cov^1] = tr[rid].lenl[tr[id].cov^1] = tr[rid].lenr[tr[id].cov^1] = 0;
        tr[lid].len[tr[id].cov] = tr[lid].lenl[tr[id].cov] = tr[lid].lenr[tr[id].cov] = len(lid);
        tr[rid].len[tr[id].cov] = tr[rid].lenl[tr[id].cov] = tr[rid].lenr[tr[id].cov] = len(rid);
        tr[id].cov = -1;
    }
}

void build(int id, int l, int r){
    tr[id].init();
    tr[id].l = l, tr[id].r = r;
    if(tr[id].l == tr[id].r){
        tr[id].sum = a[l];
        if(a[l] == 0){
            tr[id].lenl[0] = tr[id].lenr[0] = tr[id].len[0] = 1;
            tr[id].lenl[1] = tr[id].lenr[1] = tr[id].len[1] = 0;
        }
        else if(a[l] == 1){
            tr[id].lenl[0] = tr[id].lenr[0] = tr[id].len[0] = 0;
            tr[id].lenl[1] = tr[id].lenr[1] = tr[id].len[1] = 1;
        }
        return;
    }
    build(lid, l, mid);
    build(rid, mid+1, r);
    pushup(id);
}

void modify_cover(int id, int l, int r, int val){
    pushdown(id);
    if(tr[id].l == l && tr[id].r == r){
        tr[id].cov = val;
        tr[id].rev = 0;
        tr[id].sum = len(id) * val;
        tr[id].lenl[val] = tr[id].lenr[val] = tr[id].len[val] = len(id);
        tr[id].lenl[val^1] = tr[id].lenr[val^1] = tr[id].len[val^1] = 0;
        return;
    }
    if(r <= mid)    modify_cover(lid, l, r, val);
    else if(l > mid)    modify_cover(rid, l, r, val);
    else    modify_cover(lid, l, mid, val), modify_cover(rid, mid+1, r, val);
    pushup(id);
}

void modify_rev(int id, int l, int r){
    pushdown(id);
    if(tr[id].l == l && tr[id].r == r){
        swap(tr[id].lenl[0], tr[id].lenl[1]);
        swap(tr[id].lenr[0], tr[id].lenr[1]);
        swap(tr[id].len[0], tr[id].len[1]);
        tr[id].sum = len(id) - tr[id].sum;
        tr[id].rev ^= 1;
        if(tr[id].cov != -1)    tr[id].cov ^= 1;
        return;
    }
    if(r <= mid)    modify_rev(lid, l, r);
    else if(l > mid)    modify_rev(rid, l, r);
    else    modify_rev(lid, l, mid), modify_rev(rid, mid+1, r);
    pushup(id);
}

seg_tree query(int id, int l, int r){
    pushdown(id);
    if(tr[id].l == l && tr[id].r == r)
        return tr[id];
    if(r <= mid)    return query(lid, l, r);
    else if(l > mid)    return query(rid, l, r);
    else{
        seg_tree t, t1, t2;
        t.init(), t1.init(), t2.init();
        t1 = query(lid, l, mid);
        t2 = query(rid, mid+1, r);
        t.sum = t1.sum + t2.sum;
        if(t1.lenl[0] == len(lid))  t.lenl[0] = t1.lenl[0] + t2.lenl[0];
        else    t.lenl[0] = t1.lenl[0];
        if(t1.lenl[1] == len(lid))  t.lenl[1] = t1.lenl[1] + t2.lenl[1];
        else    t.lenl[1] = t1.lenl[1];
        if(t2.lenr[0] == len(rid))  t.lenr[0] = t1.lenr[0] + t2.lenr[0];
        else    t.lenr[0] = t2.lenr[0];
        if(t2.lenr[1] == len(rid))  t.lenr[1] = t1.lenr[1] + t2.lenr[1];
        else    t.lenr[1] = t2.lenr[1];
        t.len[0] = max(t1.len[0], max(t2.len[0], t1.lenr[0] + t2.lenl[0]));
        t.len[1] = max(t1.len[1], max(t2.len[1], t1.lenr[1] + t2.lenl[1]));
        return t;
    }
}

int main(){
    n = read(), q = read();
    for(int i = 1; i <= n; i++) a[i] = read();
    build(1, 1, n);
    while(q--){
        opt = read(), ql = read(), qr = read();
        ql++, qr++;
        if(opt == 0)    modify_cover(1, ql, qr, 0);
        else if(opt == 1)   modify_cover(1, ql, qr, 1);
        else if(opt == 2)   modify_rev(1, ql, qr);
        else{
            seg_tree t = query(1, ql, qr);
            if(opt == 3)    printf("%d\n", t.sum);
            else if(opt == 4)   printf("%d\n", t.len[1]);
        }
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/phantomagony/article/details/79189819