区间合并是一类问题的统称,种类很多,但在这篇博客中只需实现以下操作即可:
有一个01串,你有三种操作:
- 1.将[a, b]中的所有数字改成0
- 2.将[a, b]中的所有数字改成1
- 3.询问[a, b]中最长连续的1的长度是多少
前两种操作其实可以算作一个操作,重点在于如何高效地解决第三种操作。
虽说平衡树也可以解决这类问题,但是这里我们使用线段树来解决。
这是一个经典的老套路
线段树维护四个值(可以缩成三个,使用第四个是为了加强程序的可读性),分别是:
- lsum记录该区间左端点开始的最长连续的值为1区间
- rsum记录该区间右端点开始的最长连续的值为1区间
- sum记录该区间内最长连续的值为1的区间
- color形象解释就是记录区间的“颜色”,具体操作是当这个区间全部是1时color置1,全部为0时color置0,否则置-1。在pushup()的时候会用到。
接下来来讲讲具体的操作
首先是重中之重
这里博主有点懒,就不画图了,相信听了讲解自己脑补一下也是能搞懂的(听起来真的很简单~~)。
开始假设当前节点为
,其左儿子为
,右儿子为
,且
,
中四个值均准确无误,接下来要对
节点做
操作。
分步骤讨论:
- 1、计算
脑补一下,当 的 为1时(也就是说左儿子结点全部由1组成),那么 就是 的 (实际上 , , 的值都是左儿子结点的区间长度,换一下也没有什么大的差别)加上 的 。否则直接赋值为 。 - 2、计算
与计算 的方法类似,当 的 为1时,那么 就是 的 加上 的 。否则直接赋值为 。 - 3、计算
这应该很好想,就直接在 , , 中间取个 就可以了,其中最后一个有点特殊,想想也挺简单,因为 和 中所记录的区间是连续的(看看定义就知道了)。 - 4.计算
有了前面的经验,这个应该很简单,直接给出,脑补也不困难。
其次是
由于本人比较蒟蒻,所以我使用了一个
来存储我需要的
并逐步更新,
里面存储三个值,
,
和
,意义应该很明白,和上面的
,
,
一一对应,答案的转移与
相类似由于上面的
使用到了
来简化操作,现在的
中不便维护
,所以不能偷懒了~。
然后就没有什么可以说的了,其余操作和原版线段树类似,如果不明白可以参考这
struct SegTree {
#define lc(o) o << 1 //简化操作
#define rc(o) o << 1 | 1
#define mid ((l + r) >> 1)
struct Ans {
int ls, rs, s;
Ans(int ls, int rs, int s) : ls(ls), rs(rs), s(s) {}
};
static const int MAXSIZE = 100000 + 5;
int lsum[MAXSIZE << 2], rsum[MAXSIZE << 2], sum[MAXSIZE << 2], color[MAXSIZE << 2];
void creat(int o, int l, int r, int value) { //更新一个结点
color[o] = value;
lsum[o] = rsum[o] = sum[o] = value ? r - l + 1 : 0; //三目运算符秀了一波~
}
void pushdown(int o, int l, int r) {
if (color[o] != -1) { //如果有color那么pushdown,注意color不会在向子结点更新后发生改变
creat(lc(o), l, mid, color[o]);
creat(rc(o), mid + 1, r, color[o]);
}
}
void pushup(int o, int l, int r) { //繁琐的pushup(),具体已经解释过了
lsum[o] = color[lc(o)] == 1 ? lsum[lc(o)] + lsum[rc(o)] : lsum[lc(o)];
rsum[o] = color[rc(o)] == 1 ? rsum[lc(o)] + rsum[rc(o)] : rsum[rc(o)];
sum[o] = max(rsum[lc(o)] + lsum[rc(o)], max(sum[lc(o)], sum[rc(o)]));
if (sum[o] == 0) color[o] = 0; else if (sum[o] == r - l + 1) color[o] = 1; else color[o] = -1;
}
void set(int o, int l, int r, int L, int R, int value) { //和普通线段树一样的操作
if (l > R || r < L) return;
else if (L <= l && r <= R) creat(o, l, r, value);
else {
pushdown(o, l, r);
set(lc(o), l, mid, L, R, value);
set(rc(o), mid + 1, r, L, R, value);
pushup(o, l, r);
}
}
Ans query(int o, int l, int r, int L, int R) {
if (l > R || r < L) return Ans(0, 0, 0);
else if (L <= l && r <= R) return Ans(lsum[o], rsum[o], sum[o]);
else {
pushdown(o, l, r);
if (R <= mid) return query(lc(o), l, mid, L, R);
if (L > mid) return query(rc(o), mid + 1, r, L, R);
Ans p = query(lc(o), l, mid, L, R);
Ans q = query(rc(o), mid + 1, r, L, R);
return Ans(p.ls == mid - l + 1 ? p.ls + q.ls : p.ls,
q.rs == r - mid ? p.rs + q.rs : q.rs,
max(max(p.s, q.s), p.rs + q.ls));
}
}
void build(int o, int l, int r) {
if (l > r) return;
else if (l == r) creat(o, l, r, a[l]);
else {
build(lc(o), l, mid);
build(rc(o), mid + 1, r);
pushup(o, l, r);
}
}
};