Codeforces Round #742E Non-Decreasing Dilemma(线段树)

题目大意:
给一个数组,有两种操作
1:改变数组中的一个值
2: 求区间中成非递减的子序列有多少个
刚开始拿到题感觉:这不就是个板子吗?
然后这道板子题就卡了我两天,直到看了别人的代码才知道错哪儿了
节点需要维护的信息:
左右区间范围
当前区间中,合法子序列的数量
包含左端点的最长合法子序列的长度(llen)
包含右端点的最长合法子序列的长度(rlen)
让我们合并两个区间的时候,需要分情况讨论
如果左儿子最右端的数大于右儿子最左端的数:
父区间的合法子序列的数量就是两个子区间合法子序列的数量
如果左儿子最右端的数小于或等于右儿子最左端的数
此时父亲区间的的合法子序列的数量处理时两儿子的合之外,还要加上由于区间合并,多出来的合法子序列,其多出来的数量就是左儿子的rlen乘上右儿子的llen
此外,也可以用来以下代码来计算

int get(int x) {
    
    //计算:长度为x的区间中有多少个子区间
	return (x * (x + 1) / 2);
}
int a = get(tr[u << 1].rlen + tr[u << 1 | 1].llen);
int b = get(tr[u << 1].rlen);
int c = get(tr[u << 1 | 1].llen);
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum + (a - b - c);

题中一个重要的易错点就是query函数
当我们递归了两个子儿子后,如果左儿子最右端的数小于右儿子最左端的数,就还要考虑两儿子组合后多出来的结果,似乎套用上面的公式就行。
但是!
如果我们直接套用会导致:就算问询区间中不包含左儿子,我们也没有递归左儿子,也会导致在计算时额外多加了组合时多出来的数
这是最开始的写法

if (w[tr[u << 1].r] <= w[tr[u << 1 | 1].l]) {
    
    
			int a = get(tr[u << 1].rlen + tr[u << 1 | 1].llen);
			int b = get(tr[u << 1].rlen);
			int c = get(tr[u << 1 | 1].llen);
			//cout << a << " " << b << " " << c << endl;
			res += (a - b - c);
}

如果我们在if上面添加一个限制条件:mid >= l && mid < r
会导致:我们递归了一个左儿子的一个右儿子,但在计算时会把整个左儿子的rlen拿过来算(实际上只需要算左儿子的右儿子的rlen)
所以用这个代码才能避免

if (w[tr[u << 1].r] <= w[tr[u << 1 | 1].l] ) {
    
    
		int lsum = min(mid - l + 1, tr[u << 1].rlen);
		int rsum = min(r - mid, tr[u << 1 | 1].llen);
		if (lsum > 0 && rsum > 0)
			res += lsum * rsum;
	}
	return res;

以下是全部代码:

#include <bits/stdc++.h>
#define int long long
using namespace std;
int n, q;
const int N = 2e5 + 10;
int w[N];

struct node {
    
     //计算:区间中最大的连续子串长度 与中点向连的最大连续子串长度
	int l, r;
	int sum;//区间总个数
	int llen, rlen;//以边界的最长单调
} tr[4 * N];

int get(int x) {
    
    //计算:长度为x的区间中有多少个子区间
	return (x * (x + 1) / 2);
}

void pushup(int u) {
    
    
	tr[u].llen = tr[u << 1].llen;
	tr[u].rlen = tr[u << 1 | 1].rlen;
	if (w[tr[u << 1].r] > w[tr[u << 1 | 1].l]) {
    
    
		tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
	} else {
    
    
		int a = get(tr[u << 1].rlen + tr[u << 1 | 1].llen);
		int b = get(tr[u << 1].rlen);
		int c = get(tr[u << 1 | 1].llen);
		tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum + (a - b - c);
		if (tr[u << 1].llen == tr[u << 1].r - tr[u << 1].l + 1) {
    
    
			tr[u].llen = tr[u << 1].llen + tr[u << 1 | 1].llen;
		}
		if (tr[u << 1 | 1].rlen == tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) {
    
    
			tr[u].rlen = tr[u << 1 | 1].rlen + tr[u << 1].rlen;
		}
	}
}

void build(int u, int l, int r) {
    
    
	tr[u] = {
    
    l, r, 1, 1, 1};
	if (l == r)
		return;
	int mid = l + r >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	pushup(u);
}

void modify(int u, int x, int v) {
    
    
	if (tr[u].r == x && tr[u].l == x) {
    
    
		w[x] = v;
		return;
	}
	int mid = tr[u].l + tr[u].r >> 1;
	if (mid >= x)
		modify(u << 1, x, v);
	else
		modify(u << 1 | 1, x, v);
	pushup(u);
}

int query(int u, int l, int r) {
    
    
	if (tr[u].l >= l && tr[u].r <= r) {
    
    
		return tr[u].sum;
	}
	int mid = tr[u].l + tr[u].r >> 1;
	int res = 0;
	if (mid >= l)
		res = query(u << 1, l, r);
	if (mid < r)
		res += query(u << 1 | 1, l, r);
	/*
	这里是一个错误的写法
	if (w[tr[u << 1].r] <= w[tr[u << 1 | 1].l] && mid >= l && mid < r) {
			int a = get(tr[u << 1].rlen + tr[u << 1 | 1].llen);
			int b = get(tr[u << 1].rlen);
			int c = get(tr[u << 1 | 1].llen);
			//cout << a << " " << b << " " << c << endl;
			res += (a - b - c);
		}
	*/
	if (w[tr[u << 1].r] <= w[tr[u << 1 | 1].l] ) {
    
    
		int lsum = min(mid - l + 1, tr[u << 1].rlen);
		int rsum = min(r - mid, tr[u << 1 | 1].llen);
		if (lsum > 0 && rsum > 0)
			res += lsum * rsum;
	}
	return res;
}

void solve() {
    
    
	cin >> n >> q;
	for (int i = 1; i <= n; i++)
		cin >> w[i];
	build(1, 1, n);
	while (q--) {
    
    

		int a, b, c;
		cin >> a >> b >> c;
		if (a == 1) {
    
    
			modify(1, b, c);
		} else {
    
    
			cout << query(1, b, c) << endl;
		}
	}

}

signed main() {
    
    

	ios::sync_with_stdio(false);
	solve();


}




おすすめ

転載: blog.csdn.net/fdxgcw/article/details/120166929