emm,听说数组实现的常数更小,于是把 Treap 模板用数组重写了一遍(之前一直用指针写),结果发现指针版反而比数组版快了 56ms……QAQ
数组版 380ms
#include <cstdio>
#include <cstdlib>
// Capacity of the per-node arrays (id 0 is the "no node" sentinel) and the
// value pre()/suf() return when no predecessor/successor exists.
const int N = 100005;
const int INF = 0x3f3f3f3f;
// Node fields, indexed by node id; id 0 means "empty subtree".
int siz[N];    // number of copies stored in the whole subtree
int num[N];    // number of copies of val[] held by this node
int val[N];    // key
int rnd[N];    // random heap priority (larger priorities sit higher)
int ch[N][2];  // children: [0] smaller keys, [1] larger keys
int root;      // id of the tree root, 0 while the tree is empty
int cnt;       // number of node ids handed out so far
// Reads one decimal integer from stdin, skipping any non-digit characters
// before it. A '-' anywhere among the skipped characters makes the result
// negative. Returns 0 if EOF is reached before any digit is seen.
// NOTE(review): the value is assumed to fit in int; larger input overflows.
int read() {
    int x = 0, f = 1;
    // Must be int, not char: a char cannot hold EOF distinctly, and the old
    // skip loop therefore never terminated on truncated input.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// Returns the smaller of x and y.
int min(int x, int y) {
    return y < x ? y : x;
}
// Returns the larger of x and y.
int max(int x, int y) {
    return y > x ? y : x;
}
// Allocates the next unused node id, stores it in cur, and initialises the
// node as a single copy of x with a random priority. Ids are never reused.
// NOTE(review): nothing checks that cnt stays below N.
void newNode(int &cur, int x) {
    ++cnt;
    val[cnt] = x;
    rnd[cnt] = rand();
    num[cnt] = 1;
    siz[cnt] = 1;
    cur = cnt;
}
// Returns -1 if x equals the key of node cur. Otherwise returns the index of
// the child to descend into: 0 when x is smaller, 1 when it is larger.
int cmp(int cur, int x) {
    if (x < val[cur]) return 0;
    if (x > val[cur]) return 1;
    return -1;
}
// Recomputes siz[cur] from the node's own count and its children's sizes.
// The sentinel id 0 has size 0, so missing children contribute nothing.
void maintain(int &cur) {
    int total = num[cur];
    for (int side = 0; side < 2; ++side) total += siz[ch[cur][side]];
    siz[cur] = total;
}
// Rotates the subtree rooted at cur so that its child on side k^1 becomes the
// new root, and the old root becomes that child's k-side child. Both sizes
// are recomputed, and cur is updated to the new root.
void rotate(int &cur, int k) {
    int up = ch[cur][k ^ 1];     // child being lifted
    ch[cur][k ^ 1] = ch[up][k];  // its k-side subtree is handed to the old root
    ch[up][k] = cur;
    maintain(cur);               // old root first: it now sits below `up`
    maintain(up);
    cur = up;
}
// Inserts one copy of x into the subtree rooted at cur, creating a node when
// cur is 0. After recursing, the child is rotated above cur if its rnd is
// larger (rnd is kept max-heap ordered). siz is refreshed on the way back up.
void insert(int &cur, int x) {
    if (cur == 0) {
        newNode(cur, x);
        return;
    }
    int dir = cmp(cur, x);
    if (dir < 0) {
        ++num[cur];
    } else {
        insert(ch[cur][dir], x);
        if (rnd[ch[cur][dir]] > rnd[cur]) rotate(cur, dir ^ 1);
    }
    maintain(cur);
}
// Removes one copy of x from the subtree rooted at cur; does nothing if x is
// absent. A node with two children is first rotated down (lifting the child
// with the larger rnd) until it has at most one child, then spliced out.
void del(int &cur, int x) {
    if (cur == 0) return;
    int dir = cmp(cur, x);
    if (dir != -1) {
        del(ch[cur][dir], x);
    } else if (num[cur] > 1) {
        --num[cur];
    } else if (ch[cur][0] == 0) {
        cur = ch[cur][1];
    } else if (ch[cur][1] == 0) {
        cur = ch[cur][0];
    } else {
        int down = rnd[ch[cur][0]] < rnd[ch[cur][1]] ? 0 : 1;
        rotate(cur, down);           // the target node is now ch[cur][down]
        del(ch[cur][down], x);
    }
    maintain(cur);                   // harmless when cur became 0: siz[0] stays 0
}
// Returns 1 + the number of stored copies strictly smaller than x. This also
// holds when x itself is not in the tree.
int rank(int cur, int x) {
    int smaller = 0;
    while (cur) {
        if (val[cur] == x) return smaller + siz[ch[cur][0]] + 1;
        if (val[cur] > x) {
            cur = ch[cur][0];
        } else {
            smaller += siz[ch[cur][0]] + num[cur];
            cur = ch[cur][1];
        }
    }
    return smaller + 1;
}
// Returns the x-th smallest stored value (1-based, duplicates counted), or 0
// if x is outside [1, siz[cur]].
int kth(int cur, int x) {
    while (cur && x > 0 && siz[cur] >= x) {
        int left = siz[ch[cur][0]];
        if (x <= left) {
            cur = ch[cur][0];
        } else if (x <= left + num[cur]) {
            return val[cur];
        } else {
            x -= left + num[cur];
            cur = ch[cur][1];
        }
    }
    return 0;
}
// Returns the largest stored value strictly less than x, or -INF if none.
int pre(int cur, int x) {
    int best = -INF;
    while (cur) {
        if (val[cur] < x) {
            best = max(best, val[cur]);
            cur = ch[cur][1];
        } else {
            cur = ch[cur][0];
        }
    }
    return best;
}
// Returns the smallest stored value strictly greater than x, or INF if none.
int suf(int cur, int x) {
    int best = INF;
    while (cur) {
        if (val[cur] > x) {
            best = min(best, val[cur]);
            cur = ch[cur][0];
        } else {
            cur = ch[cur][1];
        }
    }
    return best;
}
// Reads n operations "opt x" and applies each to the treap:
// 1 insert, 2 delete, 3 rank, 4 k-th, 5 predecessor, anything else successor.
int main() {
    int n = read();
    while (n--) {
        int opt = read();
        int x = read();
        switch (opt) {
        case 1:
            insert(root, x);
            break;
        case 2:
            del(root, x);
            break;
        case 3:
            printf("%d\n", rank(root, x));
            break;
        case 4:
            printf("%d\n", kth(root, x));
            break;
        case 5:
            printf("%d\n", pre(root, x));
            break;
        default:
            printf("%d\n", suf(root, x));
            break;
        }
    }
    return 0;
}
指针版 324ms
#include <cstdio>
#include <string>
#include <cstdlib>
// Sentinel returned by pre()/suf() when no predecessor/successor exists.
const int INF = 0x3F3F3F3F;
// Reads one decimal integer from stdin, skipping any non-digit characters
// before it. A '-' anywhere among the skipped characters makes the result
// negative. Returns 0 if EOF is reached before any digit is seen.
// NOTE(review): the value is assumed to fit in int; larger input overflows.
int read() {
    int x = 0, f = 1;
    // int, not char: EOF must stay distinguishable. Explicit range checks
    // replace isdigit(), which needed <cctype> (only pulled in transitively
    // via <string>) and is UB for negative non-EOF char values.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// Returns the smaller of x and y.
int min(int x, int y) {
    return y < x ? y : x;
}
// Returns the larger of x and y.
int max(int x, int y) {
    return y > x ? y : x;
}
// One treap node. `root` is the tree root (NULL while empty), and `Pool` is
// the static storage newNode() hands nodes out from.
struct Node {
    Node *ch[2];  // children: ch[0] smaller keys, ch[1] larger keys; NULL if absent
    int v;        // key
    int r;        // random heap priority (larger priorities sit higher)
    int s;        // number of copies stored in the whole subtree
    int n;        // number of copies of v held by this node

    // -1 if x equals this node's key, otherwise the child index to descend into.
    int cmp(int x) {
        if (x < v) return 0;
        return x > v ? 1 : -1;
    }

    // (Re)initialises this node as a leaf holding one copy of x.
    void in(int x) {
        v = x;
        r = rand();
        s = 1;
        n = 1;
        ch[0] = NULL;
        ch[1] = NULL;
    }

    // Recomputes s from n and the children's subtree sizes.
    void maintain() {
        int total = n;
        for (int i = 0; i < 2; ++i)
            if (ch[i]) total += ch[i]->s;
        s = total;
    }
} *root, Pool[100005];
// Hands out a node for the caller to initialise with Node::in().
// Slots come from the static Pool first. Removed nodes are never returned to
// it, so once every slot has been used we fall back to the heap instead of
// indexing past the end of Pool (previously undefined behaviour).
Node *newNode() {
    static int cnt = 0;
    const int capacity = (int)(sizeof(Pool) / sizeof(Pool[0]));
    if (cnt < capacity) return &Pool[cnt++];
    return new Node;  // fields are set by the caller via Node::in()
}
// Rotates the subtree rooted at cur so that its child on side d^1 becomes the
// new root, and the old root becomes that child's d-side child. Both sizes
// are recomputed, and cur is updated to the new root.
void rotate(Node *&cur, int d) {
    Node *up = cur->ch[d ^ 1];     // child being lifted
    cur->ch[d ^ 1] = up->ch[d];    // its d-side subtree is handed to the old root
    up->ch[d] = cur;
    cur->maintain();               // old root first: it now sits below `up`
    up->maintain();
    cur = up;
}
// Inserts one copy of x into the subtree rooted at cur, creating a node when
// cur is NULL. After recursing, the child is rotated above cur if its r is
// larger (r is kept max-heap ordered). s is refreshed on the way back up.
void insert(Node *&cur, int x) {
    if (cur == NULL) {
        cur = newNode();
        cur->in(x);
        return;
    }
    int dir = cur->cmp(x);
    if (dir == -1) {
        ++cur->n;
    } else {
        insert(cur->ch[dir], x);
        if (cur->ch[dir]->r > cur->r) rotate(cur, dir ^ 1);
    }
    cur->maintain();
}
// Removes one copy of x from the subtree rooted at cur; does nothing if x is
// absent. A node with two children is first rotated down (lifting the child
// with the larger r) until it has at most one child, then spliced out.
void remove(Node *&cur, int x) {
    if (cur == NULL) return;
    int dir = cur->cmp(x);
    if (dir != -1) {
        remove(cur->ch[dir], x);
    } else if (cur->n > 1) {
        --cur->n;
    } else if (cur->ch[0] == NULL) {
        cur = cur->ch[1];
    } else if (cur->ch[1] == NULL) {
        cur = cur->ch[0];
    } else {
        int down = cur->ch[0]->r < cur->ch[1]->r ? 0 : 1;
        rotate(cur, down);             // the target node is now cur->ch[down]
        remove(cur->ch[down], x);
    }
    if (cur != NULL) cur->maintain();  // subtree may have become empty
}
// Returns 1 + the number of stored copies strictly smaller than x. This also
// holds when x itself is not in the tree.
int rank(Node *cur, int x) {
    int smaller = 0;
    while (cur) {
        int ls = cur->ch[0] ? cur->ch[0]->s : 0;
        if (cur->v == x) return smaller + ls + 1;
        if (cur->v > x) {
            cur = cur->ch[0];
        } else {
            smaller += ls + cur->n;
            cur = cur->ch[1];
        }
    }
    return smaller + 1;
}
// Returns the k-th smallest stored value (1-based, duplicates counted), or 0
// if k is outside [1, cur->s].
int kth(Node *cur, int k) {
    while (cur && k > 0 && cur->s >= k) {
        int ls = cur->ch[0] ? cur->ch[0]->s : 0;
        if (k <= ls) {
            cur = cur->ch[0];
        } else if (k <= ls + cur->n) {
            return cur->v;
        } else {
            k -= ls + cur->n;
            cur = cur->ch[1];
        }
    }
    return 0;
}
// Returns the largest stored value strictly less than x, or -INF if none.
int pre(Node *cur, int x) {
    int best = -INF;
    while (cur) {
        if (cur->v < x) {
            best = max(best, cur->v);
            cur = cur->ch[1];
        } else {
            cur = cur->ch[0];
        }
    }
    return best;
}
// Returns the smallest stored value strictly greater than x, or INF if none.
int suf(Node *cur, int x) {
    int best = INF;
    while (cur) {
        if (cur->v > x) {
            best = min(best, cur->v);
            cur = cur->ch[0];
        } else {
            cur = cur->ch[1];
        }
    }
    return best;
}
// Reads n operations "opt x" and applies each to the treap:
// 1 insert, 2 remove, 3 rank, 4 k-th, 5 predecessor, anything else successor.
int main() {
    int n = read();
    while (n--) {
        int opt = read();
        int x = read();
        switch (opt) {
        case 1:
            insert(root, x);
            break;
        case 2:
            remove(root, x);
            break;
        case 3:
            printf("%d\n", rank(root, x));
            break;
        case 4:
            printf("%d\n", kth(root, x));
            break;
        case 5:
            printf("%d\n", pre(root, x));
            break;
        default:
            printf("%d\n", suf(root, x));
            break;
        }
    }
    return 0;
}