https://www.luogu.org/problemnew/show/P5055
如果你还不会可持久化平衡树,请右转洛谷题解区
我们考虑一下可持久化的复杂度为什么是正确的
每一次操作时,fhq treap 的复杂度是期望 $O(\log n)$ 的,而 Splay 是均摊 $O(\log n)$ 的,这样访问节点的总次数大概是 $O(n \log n)$,保证了时间复杂度和空间复杂度的正确。(注意:均摊分析在可持久化下并不严格成立——可以反复在同一个代价很高的旧版本上操作,势能不会在版本之间共享;严格来说只有像 fhq treap 这样期望/最坏复杂度有保证的结构才能这样分析。)
我们每次操作的时候,遇到一个节点就复制一个节点;pushdown 的时候把两个孩子复制一下,再给它们打上 reverse 标记;旋转的时候也把节点复制一下。否则这次操作会影响到以前的版本。前面说明了遇到的节点总数是 $O(n \log n)$,所以这样做是正确的。
开始写代码的时候本人犯傻了,没写 Splay(下面的代码是一棵不做任何平衡的叶子树,树高没有保证,上面的复杂度分析对它并不适用),导致自己调了很久,代码也不太好看,建议大家写 fhq treap 或 Splay。
#include <bits/stdc++.h>
// Contest helpers.
// CIOS     : unties C++ streams from C stdio.
// For / Rof: iterate i over a..b upward / b..a downward.
// `register` is intentionally absent: it was removed in C++17 and clang rejects it.
#define CIOS ios::sync_with_stdio(false);
#define For(i, a, b) for (int i = (a); i <= (b); i++)
#define Rof(i, a, b) for (int i = (a); i >= (b); i--)
#define DEBUG(x) cerr << "DEBUG" << x << " >>> ";
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
// Read the next integer from stdin into f.
// Non-digit characters before the number are skipped. A '-' among them makes
// the result negative. If end of input is reached before any digit, f is
// left as 0 instead of looping forever.
template <typename T>
inline void read(T &f) {
    f = 0;
    T fu = 1;
    // getchar() returns int: a char cannot reliably hold EOF.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') fu = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        f = (f << 3) + (f << 1) + (c & 15);  // f * 10 + digit
        c = getchar();
    }
    f *= fu;
}
// Write integer x to stdout in decimal.
template <typename T>
void print(T x) {
    if (x < 0) {
        putchar('-');
        // Split off the last digit while x is still negative: negating the
        // minimum value of a signed type (e.g. LLONG_MIN) would overflow.
        T head = x / 10;
        int last = static_cast<int>(-(x % 10));  // x % 10 is in [-9, 0]
        if (head != 0) print(-head);
        putchar(last + 48);
        return;
    }
    if (x < 10) putchar(x + 48);
    else print(x / 10), putchar(x % 10 + 48);
}

// Write integer x to stdout followed by the character t.
template <typename T>
void print(T x, char t) {
    print(x);
    putchar(t);
}
// N sizes the version-root array; INF is the value given to the sentinel leaves
// that main() builds for version 0.
const int N = 2e5 + 5, INF = 0x7fffffff;
// Take the next slot of the node pool t (advancing cnt), construct a Node there
// and yield its address. Parameters are parenthesized so that expression
// arguments (e.g. conditionals) expand safely.
#define new_Node(a, c, d, e, f) (&(t[cnt++] = Node((a), (c), (d), (e), (f))))
// Allocate an internal node with children a and b and no pending reversal.
#define merge(a, b) new_Node((a)->size + (b)->size, 0, (a)->sum + (b)->sum, (a), (b))
// A node of the persistent leafy tree.
// A leaf (size 1) stores one value in `sum`. An internal node aggregates its
// two children. `rev` marks a pending reversal of this subtree whose children
// have not been swapped yet (see pushdown).
struct Node {
    int size;            // number of leaves below (0 for the empty sentinel)
    bool rev;            // pending-reversal tag
    ll sum;              // sum of the leaf values below
    Node *left, *right;

    Node() {}
    Node(int sz, bool tag, ll total, Node *lc, Node *rc)
        : size(sz), rev(tag), sum(total), left(lc), right(rc) {}
};

Node *rt[N];             // rt[i]: root of version i
Node *null;              // empty sentinel used as the children of leaves
Node t[N * 80];          // node pool handed out by new_Node
ll lastans;              // last printed answer; XOR-ed into later inputs
int n, cnt;              // number of operations; next free index in t
// Recompute size and sum of u from its children.
// A node whose left child is empty (size 0, i.e. a leaf) is left untouched,
// so a leaf keeps its own value.
inline void update(Node *u) {
    Node *lc = u->left;
    Node *rc = u->right;
    if (lc->size == 0) return;
    u->size = lc->size + rc->size;
    u->sum = lc->sum + rc->sum;
}
// Resolve a pending reversal tag on u without modifying any existing node.
// If u->rev is set, return a fresh copy of u whose children are swapped and
// whose real children are replaced by fresh copies carrying the toggled tag.
// Otherwise return u itself. Existing nodes may be shared with older versions,
// which is why they are cloned rather than changed in place.
Node *pushdown(Node *u) {
    if (!u->rev) return u;
    // Swapping the children realizes the reversal at this level.
    Node *cur = new_Node(u->size, 0, u->sum, u->right, u->left);
    // Push the tag one level down onto clones of the children. The empty
    // sentinel `null` (child of every leaf) holds no data, so it is neither
    // cloned nor tagged and remains the single shared sentinel.
    if (cur->left != null) {
        Node *c = cur->left;
        cur->left = new_Node(c->size, !c->rev, c->sum, c->left, c->right);
    }
    if (cur->right != null) {
        Node *c = cur->right;
        cur->right = new_Node(c->size, !c->rev, c->sum, c->left, c->right);
    }
    return cur;
}
// Persistently insert value y directly after the x-th leaf of u's subtree.
// Every node on the path is cloned. Returns the new subtree root.
Node *ins(Node *u, ll x, ll y) {
    Node *cur = pushdown(new_Node(u->size, u->rev, u->sum, u->left, u->right));
    if (cur->size == 1) {
        // Reached the target leaf: it becomes an internal node holding the
        // old value on the left and the new value y on the right.
        cur->left = new_Node(1, 0, cur->sum, null, null);
        cur->right = new_Node(1, 0, y, null, null);
    } else if (x <= cur->left->size) {
        cur->left = ins(cur->left, x, y);
    } else {
        cur->right = ins(cur->right, x - cur->left->size, y);
    }
    update(cur);
    return cur;
}
// Persistently remove the x-th leaf of u's subtree.
// Returns the new subtree root. When the removed leaf is a direct child, its
// sibling is returned in place of the parent.
Node *erase(Node *u, ll x) {
    Node *cur = pushdown(new_Node(u->size, u->rev, u->sum, u->left, u->right));
    Node *lc = cur->left;
    Node *rc = cur->right;
    const int leftLeaves = lc->size;
    if (x == 1 && leftLeaves == 1) return rc;
    if (x == leftLeaves + 1 && rc->size == 1) return lc;
    if (x <= leftLeaves) {
        cur->left = erase(lc, x);
    } else {
        cur->right = erase(rc, x - leftLeaves);
    }
    update(cur);
    return cur;
}
// Persistently reshape u's subtree so its left child holds exactly the first x
// leaves. The leaf order is unchanged. Returns the new (cloned) root.
Node *split(Node *u, ll x) {
    Node *cur = pushdown(new_Node(u->size, u->rev, u->sum, u->left, u->right));
    const int leftLeaves = cur->left->size;
    if (x == leftLeaves) return cur;
    if (x < leftLeaves) {
        // The left side is too big: hand its surplus over to the right side.
        Node *part = split(cur->left, x);
        cur->right = merge(part->right, cur->right);
        cur->left = part->left;
    } else {
        // The left side is too small: borrow the front of the right side.
        Node *part = split(cur->right, x - leftLeaves);
        cur->left = merge(cur->left, part->left);
        cur->right = part->right;
    }
    update(cur);
    return cur;
}
// Sum of the first k leaves, in logical order, of the subtree rooted at u.
// Read-only: instead of cloning nodes and calling pushdown, it tracks the
// parity of pending reversal tags along the path. Odd parity means a node's
// children appear in swapped order.
static ll prefixSum(Node *u, ll k) {
    ll total = 0;
    bool flipped = false;
    while (k > 0) {
        if (k >= u->size) return total + u->sum;  // whole subtree (incl. a leaf)
        flipped = (flipped != u->rev);
        Node *first = flipped ? u->right : u->left;
        Node *second = flipped ? u->left : u->right;
        if (k <= first->size) {
            u = first;
        } else {
            total += first->sum;
            k -= first->size;
            u = second;
        }
    }
    return total;
}

// Sum of elements l..r (1-based) in version `root`.
// Leaf 1 is the leading sentinel, so elements l..r are leaves l+1..r+1 and the
// answer is prefixSum(r + 1) - prefixSum(l). Unlike the previous split-based
// version, this allocates no nodes and does not rewrite rt[root].
ll query(int root, ll l, ll r) {
    return prefixSum(rt[root], r + 1) - prefixSum(rt[root], l);
}
int main() {
null = new Node(0, 0, 0, 0, 0);
rt[0] = new Node(1, 0, INF, null, null); rt[0] = ins(rt[0], 1, INF); read(n);
for(register int i = 1; i <= n; i++) {
int pre, opt; read(pre); read(opt);
if(opt == 1) {
ll x, y; read(x); read(y); x ^= lastans; y ^= lastans;
rt[i] = ins(rt[pre], x + 1, y);
}
if(opt == 2) {
ll x; read(x); x ^= lastans;
rt[i] = erase(rt[pre], x + 1);
}
if(opt == 3) {
ll l, r; read(l); read(r); l ^= lastans; r ^= lastans;
rt[i] = new_Node(rt[pre] -> size, rt[pre] -> rev, rt[pre] -> sum, rt[pre] -> left, rt[pre] -> right);
if(l == r) continue; rt[i] = split(rt[i], r + 1); rt[i] -> left = split(rt[i] -> left, l);
Node *cur = new_Node(rt[i] -> left -> right -> size, rt[i] -> left -> right -> rev, rt[i] -> left -> right -> sum, rt[i] -> left -> right -> left, rt[i] -> left -> right -> right);
rt[i] -> left -> right = cur; rt[i] -> left -> right -> rev ^= 1;
}
if(opt == 4) {
ll l, r; read(l); read(r); l ^= lastans; r ^= lastans; print(lastans = query(pre, l, r), '\n');
rt[i] = new_Node(rt[pre] -> size, rt[pre] -> rev, rt[pre] -> sum, rt[pre] -> left, rt[pre] -> right);
}
}
return 0;
}