【BZOJ4262】Sum(离线+线段树/可持久化线段树)


  • 挺不错的数据结构题。
  • 题目大意就是给出 Q ( Q 40000 ) Q(Q\le 40000) 个询问,给定参数 l 1 , r 1 , l 2 , r 2 l_1,r_1,l_2,r_2 ,询问长度 n = 100000 n=100000 的序列中, l [ l 1 , r 1 ] r [ l 2 , r 2 ] ( max i [ l , r ] a i min i [ l , r ] a i ) \sum\limits_{l\in[l_1,r_1]}\sum\limits_{r\in[l_2, r_2]}(\max\limits_{i\in[l,r]}a_i-\min\limits_{i\in[l,r]}a_i) 的值,允许离线。

  • 我们可以把 r [ l 2 , r 2 ] r\in[l_2,r_2] 这个条件转化为两个前缀信息相减的形式,即 r [ 1 , r 2 ] l [ l 1 , r 1 ] ( max i [ l , r ] a i min i [ l , r ] a i ) r [ 1 , l 2 1 ] l [ l 1 , r 1 ] ( max i [ l , r ] a i min i [ l , r ] a i ) \sum\limits_{r\in[1, r_2]}\sum\limits_{l\in[l_1,r_1]}(\max\limits_{i\in[l,r]}a_i-\min\limits_{i\in[l,r]}a_i)-\sum\limits_{r\in[1, l_2-1]}\sum\limits_{l\in[l_1,r_1]}(\max\limits_{i\in[l,r]}a_i-\min\limits_{i\in[l,r]}a_i)
  • 然后最大最小值可以分开处理,最小值只要取反一下,就可以套用最大值的过程。
  • 把询问离线,按顺序枚举 r r ,我们就能把询问转化为 r 0 = 1 r l [ l 1 , r 1 ] max i [ l , r 0 ] a i \sum\limits_{r_0=1}^r\sum\limits_{l\in[l_1,r_1]}\max\limits_{i\in[l,r_0]}a_i 的形式。
  • 这样总的询问数仍然是 O ( Q ) O(Q) 级别的。
  • 我们可以考虑动态维护一个线段树,以位置为下标,然后每个位置维护这个位置到目前枚举的 r r 这个区间的最大值。对于每个这样的询问,直接在线段树上查询。即离线线段树。
  • 然后考虑我们在右端点可选集合 [ 1 , i ) [1,i) ,中加一个 i i ,线段树会有什么变化:
  • 用单调栈求出一个最小的 k k 使得 max j [ k , i ] a j = a i \max\limits_{j\in[k,i]}a_j=a_i 。那么我们就要把 [ k , i ] [k,i] 中的最大值全部改为 a i a_i
  • 这样我们能够资瓷查询 l [ l 1 , r 1 ] max i [ l , r ] a i \sum\limits_{l\in[l_1,r_1]}\max\limits_{i\in[l,r]}a_i

  • 但是题目要求我们对历史版本求和,这时候我们就要多记录一些关于历史版本的信息,并且用下传标记实现(因为这个标记不满足交换律)。
  • 对每个线段树结点 x x 维护三个信息:
    • v a l : val: 区间内所有当前版本的最大值之和。
    • t i m : tim: 区间内所有位置中,最晚一个修改到当前版本的那个修改的时间戳。
    • s u m : sum: 该区间的 1 t i m 1 1\to tim-1 所有版本 v a l val 之和,每个版本都要各贡献一次。
  • 对每个标记的记录四个信息。
    • v a l : val: 当前要把被标记的区间的所有最大值修改为的最新版本 v a l val
    • t l : tl: 这个标记最早的修改的时间戳为 t l tl
    • t r : tr: 这个标记最新的修改的时间戳为 t r tr
    • s u m : sum: 这个标记 t l t r 1 tl\to tr-1 所有版本 v a l val 之和,每个版本都要各贡献一次。
  • 可以把每个修改当做往对应区间扔一个标记,实现起来会比较简洁。
  • 关于合并标记以及合并区间的操作实现,可以看代码。
  • 时间复杂度 O ( Q log n ) O(Q\log n)
  • 这题如果强制在线,可以用可持久化线段树维护前缀信息,在线回答询问,时间复杂度相同,空间复杂度要多一个 log \log 。由于作者太懒,在这里就不多说了。
#include <bits/stdc++.h>

inline char nextChar()
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer + buffer_size; 
	
	if (head == tail)
	{
		fread(buffer, 1, buffer_size, stdin); 
		head = buffer; 
	}
	return *head++; 
}

inline void putChar(char ch)
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer; 
	
	if (ch == '\0')
		fwrite(buffer, 1, head - buffer, stdout); 
	
	*head++ = ch; 
	if (head == tail)
		fwrite(buffer, 1, buffer_size, stdout); 
}

template <class T>
inline void putint(T x)
{
	static char buf[22]; 
	static char *tail = buf; 
	if (!x) return (void)(putChar('0')); 
	if (x < 0) x = ~x + 1, putChar('-'); 
	for (; x; x /= 10) *++tail = x % 10 + '0'; 
	for (; tail != buf; --tail) putChar(*tail); 
}

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = nextChar())); 
	x = ch - '0'; 
	while (isdigit(ch = nextChar()))
		x = x * 10 + ch - '0'; 
}

template <class T>
inline void relax(T &x, const T &y)
{
	if (x < y) x = y; 
}

const int MaxN = 1e5 + 5; 
const int MaxS = MaxN << 2; 

typedef long long s64; 

struct request
{
	int l, r, opt, num; 
	request(){} request(int a, int b, int c, int m):
		l(a), r(b), opt(c), num(m) {}
}; 

struct tags
{
	int val, tl, tr; s64 sum; 
	tags(){} tags(int a, s64 s, int b, int c):
		val(a), sum(s), tl(b), tr(c) {}
	inline tags operator + (const tags &rhs) const
	{
		if (val == -1) return rhs; 
		return tags(rhs.val, sum + 1LL * (rhs.tl - tr) * val + rhs.sum, tl, rhs.tr); 
	}
	inline void clear()
	{
		val = -1; 
	}
	inline bool empty()
	{
		return val == -1; 
	}
}tag[MaxS];

struct info
{
	s64 val, sum; int tim, len; 
	info(){} info(s64 v, s64 s, int t, int l):
		val(v), sum(s), tim(t), len(l) {}
	inline info operator + (const info &rhs) const
	{
		int t = std::max(tim, rhs.tim); 
		s64 s = sum + rhs.sum + (t - tim) * val + (t - rhs.tim) * rhs.val; 
		s64 v = val + rhs.val; 
		return info(v, s, t, len + rhs.len); 
	}
	inline info operator + (const tags &rhs) const
	{
		return info(1LL * rhs.val * len, sum + val * (rhs.tl - tim) + rhs.sum * len, rhs.tr, len); 
	}
	inline void clear()
	{
		val = sum = tim = 0, len = 1; 
	}
}seg[MaxS], res; 

int n, Q, a[MaxN]; 
s64 ans[MaxN]; 
std::vector<request> req[MaxN]; 

#define lcc x << 1, l, mid
#define rcc x << 1 | 1, mid + 1, r

inline void upt(int x)
{
	seg[x] = seg[x << 1] + seg[x << 1 | 1]; 
}

inline void modify_node(int x, const tags &del)
{
	seg[x] = seg[x] + del; 
	tag[x] = tag[x] + del; 
}

inline void dnt(int x)
{
	if (!tag[x].empty())
	{
		modify_node(x << 1, tag[x]); 
		modify_node(x << 1 | 1, tag[x]); 
		tag[x].clear(); 
	}
}

inline void build(int x, int l, int r)
{
	tag[x].clear(); 
	if (l == r) return (void)(seg[x].clear()); 
	int mid = l + r >> 1; 
	build(lcc), build(rcc); 
	upt(x); 
}

inline void query(int x, int l, int r, int u, int v)
{
	if (u > v)
		return; 
	if (u <= l && r <= v)
		return (void)(res = res + seg[x]); 
	dnt(x); 
	
	int mid = l + r >> 1; 
	if (u <= mid)
		query(lcc, u, v); 
	if (v > mid)
		query(rcc, u, v); 
}

inline void modify(int x, int l, int r, int u, int v, tags del)
{
	if (u <= l && r <= v)
		return (void)(modify_node(x, del)); 
	dnt(x); 
	
	int mid = l + r >> 1; 
	if (u <= mid)
		modify(lcc, u, v, del); 
	if (v > mid)
		modify(rcc, u, v, del); 
	upt(x); 
}

inline void solve()
{
	build(1, 1, n); 
	static int stk[MaxN], top; 
	top = 0; 
	
	for (int i = 1; i <= n; ++i)
	{
		while (top && a[stk[top]] < a[i]) --top; 
		
		modify(1, 1, n, stk[top] + 1, i, tags(a[i], 0, i, i)); 
		stk[++top] = i; 
		
		for (int j = 0, jm = req[i].size(); j < jm; ++j)
		{
			request now = req[i][j]; 
			res.clear(); 
			query(1, 1, n, now.l, now.r); 
			ans[now.num] += (res.val * (i - res.tim + 1) + res.sum) * now.opt; 
		}
	}
}

int main()
{
	read(Q); 
	for (int i = 1; i <= Q; ++i)
	{
		int l1, l2, r1, r2; 
		read(l1), read(r1), read(l2), read(r2); 
		relax(n, r1), relax(n, r2); 
		
		if (l2 > 1)
			req[l2 - 1].push_back(request(l1, r1, -1, i)); 
		req[r2].push_back(request(l1, r1, 1, i)); 
	}
	
	static const int mod = 1e9; 
	for (int i = 1, t1 = 1, t2 = 1; i <= n; ++i)
	{
		t1 = 1LL * t1 * 1023 % mod; 
		t2 = 1LL * t2 * 1025 % mod; 
		a[i] = t1 ^ t2; 
	}
	
	solve(); 
	for (int i = 1; i <= n; ++i) a[i] = ~a[i] + 1; 
	solve(); 
	
	for (int i = 1; i <= Q; ++i)
		putint(ans[i]), putChar('\n'); 
	putChar('\0'); 
	return 0; 
} 

猜你喜欢

转载自blog.csdn.net/qq_35811706/article/details/84589957