题目描述
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:翻转一个区间,例如原有序序列是5 4 3 2 1,翻转区间是[2,4]的话,结果是5 2 3 4 1
分析
不懂splay可以看一下我的博客:【传送门】
这道题目就是用splay来实现区间反转的,这个东西听说好像是LCT用splay的原因,我也不清楚没有学过LCT。
很明显,我们这道题目维护的不是权值,而是区间的编号(虽然好像还是权值),那么翻转操作就是交换两个子树的儿子的关系,但是如果每一次都暴力翻转\(O(mlog^2n)\),就做一个懒标记。
如果一个区间被旋转了两次,那么很明显,这个区间又变回去了,那么我们就维护一个标记表示表示以下的区间是否被翻转过。
那么剩下来的答案其实就是二叉查找树的中序遍历(BST的性质)。
ac代码
#include <bits/stdc++.h>
#define ll long long
#define ms(a, b) memset(a, b, sizeof(a))
#define inf 0x3f3f3f3f
#define N 100005
using namespace std;
template <typename T>
inline void read(T &x) {
x = 0; T fl = 1;
char ch = 0;
while (ch < '0' || ch > '9') {
if (ch == '-') fl = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
x = (x << 1) + (x << 3) + (ch ^ 48);
ch = getchar();
}
x *= fl;
}
struct Splay {
int rt, tot;
struct node {
int ch[2], fa, val, sz, fg;
void init(int nod, int ft) {
fa = ft;
ch[0] = ch[1] = 0;
sz = 1;
val = nod;
}
}tr[N << 1];
Splay() {
ms(tr, 0);
rt = tot = 0;
}
void pushup(int nod) {
tr[nod].sz = tr[tr[nod].ch[0]].sz + tr[tr[nod].ch[1]].sz + 1;
}
void pushdown(int nod) {
if (!tr[nod].fg) return;
tr[tr[nod].ch[0]].fg ^= 1;
tr[tr[nod].ch[1]].fg ^= 1;
tr[nod].fg = 0;
swap(tr[nod].ch[0], tr[nod].ch[1]);
}
void rotate(int nod) {
int fa = tr[nod].fa, gf = tr[fa].fa, k = tr[fa].ch[1] == nod;
tr[gf].ch[tr[gf].ch[1] == fa] = nod;
tr[nod].fa = gf;
tr[fa].ch[k] = tr[nod].ch[k ^ 1];
tr[tr[nod].ch[k ^ 1]].fa = fa;
tr[nod].ch[k ^ 1] = fa;
tr[fa].fa = nod;
pushup(fa);
pushup(nod);
}
void splay(int nod, int goal) {
while (tr[nod].fa != goal) {
int fa = tr[nod].fa, gf = tr[fa].fa;
if (gf != goal) {
if ((tr[gf].ch[0] == fa) ^ (tr[fa].ch[0] == nod)) rotate(nod);
else rotate(fa);
}
rotate(nod);
}
if (goal == 0) rt = nod;
}
int kth(int k) {
int u = rt;
while (1) {
pushdown(u);
int lc = tr[u].ch[0];
if (tr[lc].sz >= k) u = lc;
else if (tr[lc].sz + 1 == k) return u;
else k -= tr[lc].sz + 1, u = tr[u].ch[1];
}
}
void insert(int x) {
int u = rt, ft = 0;
while (u) {
ft = u;
u = tr[u].ch[x > tr[u].val];
}
u = ++ tot;
if (ft) tr[ft].ch[x > tr[ft].val] = u;
tr[u].init(x, ft);
splay(u, 0);
}
void solve(int l, int r) {
l = kth(l);
r = kth(r + 2);
splay(l, 0);
splay(r, l);
tr[tr[tr[rt].ch[1]].ch[0]].fg ^= 1;
}
}splay;
int n, m;
void dfs(int nod) {
splay.pushdown(nod);
if (splay.tr[nod].ch[0]) dfs(splay.tr[nod].ch[0]);
if (splay.tr[nod].val >= 2 && splay.tr[nod].val <= n + 1) printf("%d ", splay.tr[nod].val - 1);
if (splay.tr[nod].ch[1]) dfs(splay.tr[nod].ch[1]);
}
int main() {
read(n); read(m);
for (int i = 1; i <= n + 2; i ++) splay.insert(i);
while (m --) {
int l, r;
read(l); read(r);
splay.solve(l, r);
}
dfs(splay.rt);
return 0;
}