Can you answer these queries?
GSS系列是spoj出品的一套数据结构好毒瘤题,主要以线段树、平衡树和树链剖分为背景,进行了一些操作的魔改,使得难度远超模板题,但对于思维有极大的提升。
所以我会选择一些在我能力范围内的题挖坑选讲,构成一个GSS系列。至于剩下那些,等我成为巨佬弄懂了再说吧。
GSS5:限制端点范围的区间最大子段和
前置芝士:GSS1:区间最大子段和。(题解:CSDN, 个人博客)
题意
给定一个序列,查询左端点在 ,右端点在 的所有区间的最大子段和的最大值。
口胡
这题的特点是不直接给定区间,而是给定区间的范围,但由于限制的是端点位置,那么其实可以通过一些分类讨论,把问题转化为最简单的最大子段和。
首先先按给定的两个集合之间的关系分类:
-
给定的区间无重叠。那么显然区间 是必选的,那么我们只要在 中求出 , 中求出 ,再合并一下即可。
-
给定的区间有重叠。这种情况比较复杂,需要再进行分类统计答案主要分为
- 取重叠部分的
- 取重叠部分前的 和重叠部分的
- 取重叠部分后的 和重叠部分的
- 取重叠部分的 并与前后合并
其实2,3,4的求解方法是一样的,只需要稍作改动即可。(4其实也可以转换为2,3来求)
具体可以结合图解即代码理解。
图解
由于情况2,3,4几乎是一样的,于是就画成了相同的颜色。其实是找不出颜色了
对于情况1,也就是区间端点在重叠部分,那么显然是图中红色线段的形式,即重叠部分的 。
对于情况2,3,4,一定是某一条分割线前的 加上分割线后的 合并而来。这也是为什么在算2,3的时候可以同时把4算了。
如上图,假如我们正在算情况二,那么我们本来要算的是 的 加上 的 ,但我们发现要是把 改为 ,那么就是情况4的答案,但由于这两个区间是存在包含关系的,那么显然在算 的时候会计算到 。也就是说我们可以同时算出情况2和情况4(部分)的答案。那么情况三的时候也是同理,2,3拼起来,就计算出了4的完整答案。
核心代码:
ll ans = query(1,1,n, l2, r1).mx;
if(l1 < l2) ans = max(ans, query(1,1,n, l1, l2).rs+max(query(1,1,n, l2+1, r2).ls, 0ll));
if(r2 > r1) ans = max(ans, query(1,1,n, l1, r1).rs+max(query(1,1,n, r1+1, r2).ls, 0ll));
这段代码中的两个if
,特判了两个区间完全相等的情况。
完整代码
注意区间的边界问题,防止重复计算
#include <bits/stdc++.h>
#define ll long long
#define MAX 10005
#define lc(p) (p<<1)
#define rc(p) (p<<1|1)
#define mid ((l+r)>>1)
using namespace std;
int n, m;
struct node {
ll sum, ls, rs, mx;
node() {
sum = ls = rs = mx = 0;
}
} s[MAX*4];
ll a[MAX];
inline node merge(node a, node b) {
node res;
res.sum = a.sum + b.sum;
res.ls = max(a.ls, a.sum+b.ls);
res.rs = max(b.rs, b.sum+a.rs);
res.mx = max(max(a.mx, b.mx), a.rs+b.ls);
return res;
}
inline void push_up(int p) {
s[p] = merge(s[lc(p)], s[rc(p)]);
}
void build(int p, int l, int r) {
if(l == r) {
s[p].sum = s[p].ls = s[p].rs = s[p].mx = a[l];
return;
}
build(lc(p), l, mid);
build(rc(p), mid+1, r);
push_up(p);
}
node query(int p, int l, int r, int ul, int ur) {
if(ul > ur) return s[0];
if(l>=ul && r<=ur) {
return s[p];
}
if(mid < ul) {
return query(rc(p), mid+1, r, ul, ur);
} else if(mid >= ur) {
return query(lc(p), l, mid, ul, ur);
} else {
node t1 = query(lc(p), l, mid, ul, ur);
node t2 = query(rc(p), mid+1, r, ul, ur);
return merge(t1, t2);
}
}
ll get_ans(int l1, int r1, int l2, int r2) {
if(r1 < l2) {
ll l = query(1,1,n, l1, r1).rs;
ll r = query(1,1,n, l2, r2).ls;
return l+r+query(1,1,n, r1+1, l2-1).sum;
}
ll ans = query(1,1,n, l2, r1).mx;
if(l1 < l2) ans = max(ans, query(1,1,n, l1, l2).rs+max(query(1,1,n, l2+1, r2).ls, 0ll));
if(r2 > r1) ans = max(ans, query(1,1,n, l1, r1).rs+max(query(1,1,n, r1+1, r2).ls, 0ll));
return ans;
}
void solve() {
memset(s, 0, sizeof(s));
cin >> n;
for(int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
}
build(1, 1, n);
cin >> m;
int l1, r1, l2, r2;
while(m--) {
scanf("%d%d%d%d", &l1, &r1, &l2, &r2);
printf("%lld\n", get_ans(l1,r1,l2,r2));
}
}
int main() {
int t;
cin >> t;
while(t--) {
solve();
}
return 0;
}