题目
描述
lxhgww最近收到了一个01序列,序列里面包含了n个数,这些数要么是0,要么是1,现在对于这个序列有五种变换操作和询问操作:
0 a b 把[a, b]区间内的所有数全变成0
1 a b 把[a, b]区间内的所有数全变成1
2 a b 把[a,b]区间内的所有数全部取反,也就是说把所有的0变成1,把所有的1变成0
3 a b 询问[a, b]区间内总共有多少个1
4 a b 询问[a, b]区间内最多有多少个连续的1
对于每一种询问操作,lxhgww都需要给出回答,聪明的程序员们,你们能帮助他吗?
输入
输入数据第一行包括2个数,n和m,分别表示序列的长度和操作数目
第二行包括n个数,表示序列的初始状态
接下来m行,每行3个数,op, a, b,(0<=op<=4,0<=a<=b
输出
对于每一个询问操作,输出一行,包括1个数,表示其对应的答案
输入样例
10 10
0 0 0 1 1 0 1 0 1 1
1 0 2
3 0 5
2 2 2
4 0 4
0 3 6
2 3 7
4 2 8
1 0 5
0 5 6
3 3 9
输出样例
5
2
6
5
说明
对于30%的数据,1<=n, m<=1000
对于100%的数据,1<=n, m<=100000
解题思路
没有区间翻转,那么线段树可以维护。
操作分析
- 0 a b 把[a, b]区间内的所有数全变成0
- 1 a b 把[a, b]区间内的所有数全变成1
以上两个均为区间覆盖操作,因此我们需要一个cov标记(初值为-1表示无覆盖,值为0表示覆盖为0,值为1表示覆盖为1) - 2 a b 把[a,b]区间内的所有数全部取反,也就是说把所有的0变成1,把所有的1变成0
区间反转操作,因此我们还需要一个rev标记,那么我们得好好考虑双标记问题了 - 3 a b 询问[a, b]区间内总共有多少个1
区间查询1的个数,其实也就是区间求和,维护一个sum即可 - 4 a b 询问[a, b]区间内最多有多少个连续的1
区间查询最大连续子序列问题,我们需要维护每个节点的mx1(最大连续1的个数)、lmx1(从左端点开始往右走最大连续1的个数)、rmx1(从右端点开始往左走最大连续1的个数),又因为这道题有区间反转操作,我们还要维护mx0、lmx0、rmx0。有点恶心……
线段树节点信息
根据上面的分析,线段树节点长这样:
struct segTree{
int l, r; //Basic info.
int sum, mx[2], lmx[2], rmx[2]; //Maintained info.
int cov, rev; //tags
segTree(){
l = r = 0;
sum = mx[0] = mx[1] = lmx[0] = lmx[1] = rmx[0] = rmx[1] = 0;
cov = -1, rev = 0;
}
}tr[N<<2];
双标记问题
这道题需要两个标记,而这两个标记又互相有影响,因此我们需要给这两个标记定义优先级
- 若cov优先级更高,则每次cov操作后要将rev清空(其实此时不清空也是正确的,只不过rev标记没有了意义),pushdown时先下放cov标记
- 若rev优先级更高,则每次rev操作后要将cov取反(cov ^= 1),这样才能保证操作的正确性,pushdown时先下放rev标记
事实上,这两种方式都是可以的,只不过后一种稍显麻烦一点(两种代码均在文末给出)
求最大连续1的个数
这种问题也挺常见的,平衡树的题中也有(NOI2005维护数列·题解),所以单独拿出来说一下。
这种问题不仅要存一个mx(最大连续1的个数),还要存 lmx(从左端点开始往右走最大连续1的个数)和 rmx(从右端点开始往左走最大连续1的个数),向上更新时分是否跨越区间维护,应该还是比较好理解。
inline void pushup(int id){
tr[id].sum = tr[lid].sum + tr[rid].sum;
for(int i = 0; i <= 1; i++){
tr[id].mx[i] = max(max(tr[lid].mx[i], tr[rid].mx[i]), tr[lid].rmx[i] + tr[rid].lmx[i]);
if(tr[lid].lmx[i] == size(lid)) tr[id].lmx[i] = tr[lid].lmx[i] + tr[rid].lmx[i];
else tr[id].lmx[i] = tr[lid].lmx[i];
if(tr[rid].rmx[i] == size(rid)) tr[id].rmx[i] = tr[rid].rmx[i] + tr[lid].rmx[i];
else tr[id].rmx[i] = tr[rid].rmx[i];
}
}
询问处理
在解决询问时,我们需要节点的很多信息,所以为了降低常数,可以把询问函数类型定义为线段树结构体类型,方便处理
两份代码写的时间隔了几个月,所以变量名和风格稍有不同……
Code#1
cov优先
#include<cstdio>
#include<algorithm>
#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l+tr[id].r)>>1)
#define size(id) (tr[id].r-tr[id].l+1)
using namespace std;
const int N = 100005;
int n, m, a[N], opt, ql, qr;
struct segTree{
int l, r; //Basic info.
int sum, mx[2], lmx[2], rmx[2]; //Maintained info.
int cov, rev; //tags
segTree(){
l = r = 0;
sum = mx[0] = mx[1] = lmx[0] = lmx[1] = rmx[0] = rmx[1] = 0;
cov = -1, rev = 0;
}
}tr[N<<2];
struct OPT_segTree{
inline void pushup(int id){
tr[id].sum = tr[lid].sum + tr[rid].sum;
for(int i = 0; i <= 1; i++){
tr[id].mx[i] = max(max(tr[lid].mx[i], tr[rid].mx[i]), tr[lid].rmx[i] + tr[rid].lmx[i]);
if(tr[lid].lmx[i] == size(lid)) tr[id].lmx[i] = tr[lid].lmx[i] + tr[rid].lmx[i];
else tr[id].lmx[i] = tr[lid].lmx[i];
if(tr[rid].rmx[i] == size(rid)) tr[id].rmx[i] = tr[rid].rmx[i] + tr[lid].rmx[i];
else tr[id].rmx[i] = tr[rid].rmx[i];
}
}
inline void pushdown(int id){
if(!id || tr[id].l == tr[id].r) return;
if(tr[id].cov != -1){
tr[lid].sum = tr[id].cov * size(lid);
tr[lid].lmx[tr[id].cov] = tr[lid].rmx[tr[id].cov] = tr[lid].mx[tr[id].cov] = size(lid);
tr[lid].lmx[!tr[id].cov] = tr[lid].rmx[!tr[id].cov] = tr[lid].mx[!tr[id].cov] = 0;
tr[lid].cov = tr[id].cov, tr[lid].rev = 0;
tr[rid].sum = tr[id].cov * size(rid);
tr[rid].lmx[tr[id].cov] = tr[rid].rmx[tr[id].cov] = tr[rid].mx[tr[id].cov] = size(rid);
tr[rid].lmx[!tr[id].cov] = tr[rid].rmx[!tr[id].cov] = tr[rid].mx[!tr[id].cov] = 0;
tr[rid].cov = tr[id].cov, tr[rid].rev = 0;
tr[id].cov = -1;
}
if(tr[id].rev){
swap(tr[lid].mx[0], tr[lid].mx[1]);
swap(tr[lid].lmx[0], tr[lid].lmx[1]);
swap(tr[lid].rmx[0], tr[lid].rmx[1]);
tr[lid].sum = size(lid) - tr[lid].sum;
tr[lid].rev ^= 1;
swap(tr[rid].mx[0], tr[rid].mx[1]);
swap(tr[rid].lmx[0], tr[rid].lmx[1]);
swap(tr[rid].rmx[0], tr[rid].rmx[1]);
tr[rid].sum = size(rid) - tr[rid].sum;
tr[rid].rev ^= 1;
tr[id].rev = 0;
}
}
void build(int id, int l, int r){
tr[id].l = l, tr[id].r = r;
if(tr[id].l == tr[id].r){
if(a[l] == 1) tr[id].sum = tr[id].lmx[1] = tr[id].rmx[1] = tr[id].mx[1] = 1;
if(a[l] == 0) tr[id].sum = 0, tr[id].lmx[0] = tr[id].rmx[0] = tr[id].mx[0] = 1;
return;
}
build(lid, l, mid);
build(rid, mid+1, r);
pushup(id);
}
void cover(int id, int l, int r, int val){
pushdown(id);
if(tr[id].l == l && tr[id].r == r){
tr[id].sum = val * size(id);
tr[id].lmx[val] = tr[id].rmx[val] = tr[id].mx[val] = size(id);
tr[id].lmx[!val] = tr[id].rmx[!val] = tr[id].mx[!val] = 0;
tr[id].cov = val, tr[id].rev = 0;
return;
}
if(r <= mid) cover(lid, l, r, val);
else if(l > mid) cover(rid, l, r, val);
else cover(lid, l, mid, val), cover(rid, mid+1, r, val);
pushup(id);
}
void reverse(int id, int l, int r){
pushdown(id);
if(tr[id].l == l && tr[id].r == r){
swap(tr[id].mx[0], tr[id].mx[1]);
swap(tr[id].lmx[0], tr[id].lmx[1]);
swap(tr[id].rmx[0], tr[id].rmx[1]);
tr[id].sum = size(id) - tr[id].sum;
tr[id].rev ^= 1;
return;
}
if(r <= mid) reverse(lid, l, r);
else if(l > mid) reverse(rid, l, r);
else reverse(lid, l, mid), reverse(rid, mid+1, r);
pushup(id);
}
int querySum(int id, int l, int r){
pushdown(id);
if(tr[id].l == l && tr[id].r == r) return tr[id].sum;
if(r <= mid) return querySum(lid, l, r);
else if(l > mid) return querySum(rid, l, r);
else return querySum(lid, l, mid) + querySum(rid, mid+1, r);
}
segTree querySub(int id, int l, int r){
pushdown(id);
if(tr[id].l == l && tr[id].r == r) return tr[id];
if(r <= mid) return querySub(lid, l, r);
else if(l > mid) return querySub(rid, l, r);
else{
segTree L = querySub(lid, l, mid), R = querySub(rid, mid+1, r), res;
res.l = l, res.r = r;
res.sum = L.sum + R.sum;
for(int i = 0; i <= 1; i++){
res.mx[i] = max(max(L.mx[i], R.mx[i]), L.rmx[i] + R.lmx[i]);
if(L.lmx[i] == L.r - L.l + 1) res.lmx[i] = L.lmx[i] + R.lmx[i];
else res.lmx[i] = L.lmx[i];
if(R.rmx[i] == R.r - R.l + 1) res.rmx[i] = R.rmx[i] + L.rmx[i];
else res.rmx[i] = R.rmx[i];
}
return res;
}
}
}Seg;
int main(){
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
Seg.build(1, 1, n);
while(m--){
scanf("%d%d%d", &opt, &ql, &qr);
ql++, qr++;
if(opt == 0) Seg.cover(1, ql, qr, 0);
else if(opt == 1) Seg.cover(1, ql, qr, 1);
else if(opt == 2) Seg.reverse(1, ql, qr);
else if(opt == 3) printf("%d\n", Seg.querySum(1, ql, qr));
else if(opt == 4) printf("%d\n", Seg.querySub(1, ql, qr).mx[1]);
}
return 0;
}
Code#2
rev优先
#include<cstdio>
#include<algorithm>
#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l + tr[id].r) >> 1)
#define len(id) (tr[id].r - tr[id].l + 1)
using namespace std;
inline int read(){
int x = 0;
bool fl = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){
if(ch == '-') fl = 0;
ch = getchar();
}
while(ch >= '0' && ch <= '9'){
x = (x << 1) + (x << 3) + ch - '0';
ch = getchar();
}
return fl ? x : -x;
}
const int N = 100005;
int n, q, a[N], opt, ql, qr;
struct seg_tree{
int l, r;
int sum, lenl[2], lenr[2], len[2];//七个变量:多少个 1,从左、从右、区间最长连续 0/1
int rev, cov;//两个标记:取反标记 & 赋值标记
void init(){
sum = lenl[0] = lenr[0] = lenl[1] = lenr[1] = 0;
cov = -1, rev = 0;//注意初值
}
}tr[N<<2];
void pushup(int id){
tr[id].sum = tr[lid].sum + tr[rid].sum;
if(tr[lid].lenl[0] == len(lid)) tr[id].lenl[0] = tr[lid].lenl[0] + tr[rid].lenl[0];
else tr[id].lenl[0] = tr[lid].lenl[0];
if(tr[lid].lenl[1] == len(lid)) tr[id].lenl[1] = tr[lid].lenl[1] + tr[rid].lenl[1];
else tr[id].lenl[1] = tr[lid].lenl[1];
if(tr[rid].lenr[0] == len(rid)) tr[id].lenr[0] = tr[lid].lenr[0] + tr[rid].lenr[0];
else tr[id].lenr[0] = tr[rid].lenr[0];
if(tr[rid].lenr[1] == len(rid)) tr[id].lenr[1] = tr[lid].lenr[1] + tr[rid].lenr[1];
else tr[id].lenr[1] = tr[rid].lenr[1];
tr[id].len[0] = max(tr[lid].len[0], max(tr[rid].len[0], tr[lid].lenr[0] + tr[rid].lenl[0]));
tr[id].len[1] = max(tr[lid].len[1], max(tr[rid].len[1], tr[lid].lenr[1] + tr[rid].lenl[1]));
}
void pushdown(int id){
if(tr[id].l == tr[id].r) return;
if(tr[id].rev){
swap(tr[lid].lenl[0], tr[lid].lenl[1]), swap(tr[rid].lenl[0], tr[rid].lenl[1]);
swap(tr[lid].lenr[0], tr[lid].lenr[1]), swap(tr[rid].lenr[0], tr[rid].lenr[1]);
swap(tr[lid].len[0], tr[lid].len[1]), swap(tr[rid].len[0], tr[rid].len[1]);
tr[lid].sum = len(lid) - tr[lid].sum, tr[rid].sum = len(rid) - tr[rid].sum;
tr[lid].rev ^= tr[id].rev, tr[rid].rev ^= tr[id].rev;
if(tr[lid].cov != -1) tr[lid].cov ^= tr[id].rev;
if(tr[rid].cov != -1) tr[rid].cov ^= tr[id].rev;
tr[id].rev = 0;
}
if(tr[id].cov != -1){
tr[lid].cov = tr[rid].cov = tr[id].cov;
tr[lid].sum = len(lid) * tr[id].cov;
tr[rid].sum = len(rid) * tr[id].cov;
tr[lid].len[tr[id].cov^1] = tr[lid].lenl[tr[id].cov^1] = tr[lid].lenr[tr[id].cov^1] = 0;
tr[rid].len[tr[id].cov^1] = tr[rid].lenl[tr[id].cov^1] = tr[rid].lenr[tr[id].cov^1] = 0;
tr[lid].len[tr[id].cov] = tr[lid].lenl[tr[id].cov] = tr[lid].lenr[tr[id].cov] = len(lid);
tr[rid].len[tr[id].cov] = tr[rid].lenl[tr[id].cov] = tr[rid].lenr[tr[id].cov] = len(rid);
tr[id].cov = -1;
}
}
void build(int id, int l, int r){
tr[id].init();
tr[id].l = l, tr[id].r = r;
if(tr[id].l == tr[id].r){
tr[id].sum = a[l];
if(a[l] == 0){
tr[id].lenl[0] = tr[id].lenr[0] = tr[id].len[0] = 1;
tr[id].lenl[1] = tr[id].lenr[1] = tr[id].len[1] = 0;
}
else if(a[l] == 1){
tr[id].lenl[0] = tr[id].lenr[0] = tr[id].len[0] = 0;
tr[id].lenl[1] = tr[id].lenr[1] = tr[id].len[1] = 1;
}
return;
}
build(lid, l, mid);
build(rid, mid+1, r);
pushup(id);
}
void modify_cover(int id, int l, int r, int val){
pushdown(id);
if(tr[id].l == l && tr[id].r == r){
tr[id].cov = val;
tr[id].rev = 0;
tr[id].sum = len(id) * val;
tr[id].lenl[val] = tr[id].lenr[val] = tr[id].len[val] = len(id);
tr[id].lenl[val^1] = tr[id].lenr[val^1] = tr[id].len[val^1] = 0;
return;
}
if(r <= mid) modify_cover(lid, l, r, val);
else if(l > mid) modify_cover(rid, l, r, val);
else modify_cover(lid, l, mid, val), modify_cover(rid, mid+1, r, val);
pushup(id);
}
void modify_rev(int id, int l, int r){
pushdown(id);
if(tr[id].l == l && tr[id].r == r){
swap(tr[id].lenl[0], tr[id].lenl[1]);
swap(tr[id].lenr[0], tr[id].lenr[1]);
swap(tr[id].len[0], tr[id].len[1]);
tr[id].sum = len(id) - tr[id].sum;
tr[id].rev ^= 1;
if(tr[id].cov != -1) tr[id].cov ^= 1;
return;
}
if(r <= mid) modify_rev(lid, l, r);
else if(l > mid) modify_rev(rid, l, r);
else modify_rev(lid, l, mid), modify_rev(rid, mid+1, r);
pushup(id);
}
seg_tree query(int id, int l, int r){
pushdown(id);
if(tr[id].l == l && tr[id].r == r)
return tr[id];
if(r <= mid) return query(lid, l, r);
else if(l > mid) return query(rid, l, r);
else{
seg_tree t, t1, t2;
t.init(), t1.init(), t2.init();
t1 = query(lid, l, mid);
t2 = query(rid, mid+1, r);
t.sum = t1.sum + t2.sum;
if(t1.lenl[0] == len(lid)) t.lenl[0] = t1.lenl[0] + t2.lenl[0];
else t.lenl[0] = t1.lenl[0];
if(t1.lenl[1] == len(lid)) t.lenl[1] = t1.lenl[1] + t2.lenl[1];
else t.lenl[1] = t1.lenl[1];
if(t2.lenr[0] == len(rid)) t.lenr[0] = t1.lenr[0] + t2.lenr[0];
else t.lenr[0] = t2.lenr[0];
if(t2.lenr[1] == len(rid)) t.lenr[1] = t1.lenr[1] + t2.lenr[1];
else t.lenr[1] = t2.lenr[1];
t.len[0] = max(t1.len[0], max(t2.len[0], t1.lenr[0] + t2.lenl[0]));
t.len[1] = max(t1.len[1], max(t2.len[1], t1.lenr[1] + t2.lenl[1]));
return t;
}
}
int main(){
n = read(), q = read();
for(int i = 1; i <= n; i++) a[i] = read();
build(1, 1, n);
while(q--){
opt = read(), ql = read(), qr = read();
ql++, qr++;
if(opt == 0) modify_cover(1, ql, qr, 0);
else if(opt == 1) modify_cover(1, ql, qr, 1);
else if(opt == 2) modify_rev(1, ql, qr);
else{
seg_tree t = query(1, ql, qr);
if(opt == 3) printf("%d\n", t.sum);
else if(opt == 4) printf("%d\n", t.len[1]);
}
}
return 0;
}