感觉都是一些模板操作,用区间翻转来写比合并前驱、合并后继等做法简单多了,但是这题暴露了我还没有真正理解按插入顺序排序的splay树。
无注释版:
#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-node array below. Node indices start at 1; index 0
// is used as the "no node" value throughout the tree routines.
// NOTE(review): presumably sized for the problem's maximum n plus the two
// sentinel nodes created in main -- confirm against the problem limits.
const int N=8e4+5;
// n, m      : the two counts read at the start of main (elements, commands).
// tmp,x,s,t : scratch values used while parsing and executing commands.
// ncnt      : number of tree nodes allocated so far (index of the newest node).
// root      : index of the current root node, 0 while the tree is empty.
int n,m,tmp,x,s,t,ncnt,root;
// Buffer for the command word read with scanf in main (only str[0] is inspected).
char str[10];
// ch[x][0/1] : left / right child of node x.
// fa[x]      : parent of node x.
// val[x]     : value stored in node x.
// cnt[x]     : multiplicity of node x (insert always sets it to 1).
// size[x]    : number of elements in the subtree rooted at x (see pushup).
// tag[x]     : pending "reverse this subtree" flag, resolved by pushdown.
// pos[v]     : index of the node that holds value v (filled in main).
// NOTE(review): `size` has the same name as std::size (C++17). Because
// `using namespace std;` is in effect, an unqualified `size` is ambiguous when
// built as C++17 or later -- refer to it as ::size or rename it.
int ch[N][2],fa[N],val[N],cnt[N],size[N],tag[N],pos[N];
// Explicit stack (top = number of entries) used by splay to push pending
// reversal tags down from the root before rotating.
int top,sta[N];
// Tells which side of its parent node x hangs on:
// 1 if x is the parent's right child, 0 otherwise.
inline int chk(int x)
{
    const int parent = fa[x];
    return x == ch[parent][1];
}
inline void pushup(int x)
{
if (!x) return;
size[x]=size[ch[x][0]]+size[ch[x][1]]+cnt[x];
}
// Single rotation that lifts node x one level, above its parent.
// The subtree of x that faces the parent is handed over to the parent, x takes
// the parent's place under the grandparent, and the parent becomes a child of x.
// Sizes are then refreshed bottom-up (old parent first, then x).
// When the handed-over subtree or the grandparent is absent, the corresponding
// writes land in slot 0 of fa/ch; pushup and pushdown both ignore node 0.
inline void rotate(int x)
{
    const int parent = fa[x];
    const int grand = fa[parent];
    const int side = chk(x);            // which child of parent x currently is
    const int inner = ch[x][side ^ 1];  // subtree of x that moves to parent

    ch[parent][side] = inner;
    fa[inner] = parent;

    ch[grand][chk(parent)] = x;
    fa[x] = grand;

    ch[x][side ^ 1] = parent;
    fa[parent] = x;

    pushup(parent);
    pushup(x);
}
// Resolves a pending reversal on node x: its two children trade places and
// each of them inherits the reversal flag. Does nothing for node 0 or for a
// node whose flag is clear.
inline void pushdown(int x)
{
    if (!x || !tag[x]) return;

    std::swap(ch[x][0], ch[x][1]);
    tag[ch[x][0]] ^= 1;
    tag[ch[x][1]] ^= 1;
    tag[x] = 0;
}
// Rotates node x upwards until its parent is `goal`; with goal == 0 the node
// becomes the root and `root` is updated.
// Before any rotation, every pending reversal flag on the path from the root
// down to x is pushed down (ancestors first), so the child pointers seen by
// the rotations are already in their final orientation.
inline void splay(int x,int goal)
{
    // Collect x and all of its ancestors on the explicit stack ...
    for (int node = x; node; node = fa[node]) sta[++top] = node;
    // ... then resolve their flags from the topmost ancestor down to x.
    for (; top; --top) pushdown(sta[top]);

    while (fa[x] != goal)
    {
        const int parent = fa[x];
        const int grand = fa[parent];
        if (grand != goal)
        {
            // Same side as the parent: rotate the parent first (zig-zig);
            // opposite sides: rotate x twice (zig-zag).
            rotate(chk(x) == chk(parent) ? parent : x);
        }
        rotate(x);
    }

    if (!goal) root = x;
}
inline void insert(int x)
{
ncnt++;
if (root) ch[root][1]=ncnt;
fa[ncnt]=root; val[ncnt]=x;
ch[ncnt][0]=ch[ncnt][1]=0;
cnt[ncnt]=size[ncnt]=1;
splay(ncnt,0);
}
inline int kth(int x)
{
int cur=root;
while (true)
{
pushdown(cur);
if (x<=size[ch[cur][0]]) cur=ch[cur][0];
else if (x>size[ch[cur][0]]+cnt[cur])
{
x-=size[ch[cur][0]]+cnt[cur];
cur=ch[cur][1];
}
else return cur;
}
}
// Marks the elements at ranks l..r (1-based, inclusive) as reversed.
// The element just before the range is splayed to the root and the element
// just after it directly below the root; everything in between then sits in
// the left subtree of the root's right child, which receives the lazy flag.
// Ranks l-1 and r+1 must exist (main keeps a sentinel at each end for this).
inline void reverse(int l,int r)
{
    const int before = kth(l - 1);
    const int after = kth(r + 1);
    splay(before, 0);
    splay(after, before);

    const int segment = ch[ch[root][1]][0];
    if (segment) tag[segment] ^= 1;
}
int main(){
scanf("%d%d",&n,&m);
insert(-1e8);
for (register int i=1; i<=n; ++i)
{
scanf("%d",&tmp);
pos[tmp]=i+1;
insert(tmp);
}
insert(1e8);
while (m--)
{
scanf("%s",str);
scanf("%d",&x);
if (str[0]=='T')
{
splay(pos[x],0);
s=size[ch[root][0]]+1;
if (s==2) continue;
reverse(2,s);
reverse(3,s);
}
if (str[0]=='B')
{
splay(pos[x],0);
s=size[ch[root][0]]+1;
reverse(s,n+1);
reverse(s,n);
}
if (str[0]=='I')
{
splay(pos[x],0);
s=size[ch[root][0]]+1;
scanf("%d",&t);
if (t==-1)
{
reverse(s-1,s);
}
if (t==0) continue;
if (t==1)
{
reverse(s,s+1);
}
}
if (str[0]=='A')
{
splay(pos[x],0);
printf("%d\n",size[ch[root][0]]-1);
}
if (str[0]=='Q')
{
printf("%d\n",val[kth(x+1)]);
}
}
return 0;
}
注释版:
#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-node array below. Node indices start at 1; index 0
// is used as the "no node" value throughout the tree routines.
// NOTE(review): presumably sized for the problem's maximum n plus the two
// sentinel nodes created in main -- confirm against the problem limits.
const int N=8e4+5;
// n, m   : the two counts read at the start of main (elements, commands).
// x, l, t: scratch values used while parsing and executing commands.
// ncnt   : number of tree nodes allocated so far (index of the newest node).
// root   : index of the current root node, 0 while the tree is empty.
int n,m,x,l,t,ncnt,root;
// a[i]   : i-th input value; pos[v] : index of the node that holds value v.
int a[N],pos[N];
// Explicit stack (dep = number of entries) used by splay to push pending
// reversal tags down from the root before rotating.
int dep,sta[N];
// ch[x][0/1] : left / right child of node x.
// fa[x]      : parent of node x.
// val[x]     : value stored in node x.
// cnt[x]     : multiplicity of node x (insert always sets it to 1).
// size[x]    : number of elements in the subtree rooted at x (see pushup).
// tag[x]     : pending "reverse this subtree" flag, resolved by pushdown.
// NOTE(review): `size` has the same name as std::size (C++17). Because
// `using namespace std;` is in effect, an unqualified `size` is ambiguous when
// built as C++17 or later -- refer to it as ::size or rename it.
int ch[N][2],fa[N],val[N],cnt[N],size[N],tag[N];
// Buffer for the command word read with scanf in main (only str[0] is inspected).
char str[10];
// Applies the lazy reversal stored on node x, if any: the flag is handed to
// both children, cleared on x, and the children are exchanged.
// Node 0 and nodes without a pending flag are left as they are.
inline void pushdown(int x)
{
    if (x == 0) return;
    if (tag[x] == 0) return;

    const int leftChild = ch[x][0];
    const int rightChild = ch[x][1];
    tag[leftChild] ^= 1;
    tag[rightChild] ^= 1;
    tag[x] = 0;
    ch[x][0] = rightChild;
    ch[x][1] = leftChild;
}
// Side of node x below its parent: 1 for a right child, 0 otherwise.
inline int chk(int x)
{
    const int parent = fa[x];
    return x == ch[parent][1];
}
inline void pushup(int x)
{
if (!x) return;
size[x]=size[ch[x][0]]+size[ch[x][1]]+cnt[x];
}
// Single rotation that moves node x above its parent.
// x gives its parent-facing subtree to the parent, replaces the parent under
// the grandparent, and adopts the parent as a child. Subtree sizes are then
// rebuilt for the old parent and for x, in that order.
// If the moved subtree or the grandparent does not exist, the matching writes
// go to slot 0 of fa/ch; pushup and pushdown both ignore node 0.
inline void rotate(int x)
{
    const int up = fa[x];
    const int upup = fa[up];
    const int dir = chk(x);            // side of x below its parent
    const int moved = ch[x][dir ^ 1];  // subtree handed over to the parent

    ch[up][dir] = moved;
    fa[moved] = up;

    ch[upup][chk(up)] = x;
    fa[x] = upup;

    ch[x][dir ^ 1] = up;
    fa[up] = x;

    pushup(up);
    pushup(x);
}
inline void splay(int x,int goal)
{
int f=x;
sta[++dep]=f;
while (f)
{
f=fa[f];
sta[++dep]=f;
}
while (dep) pushdown(sta[dep--]);
while (fa[x]!=goal)
{
int y=fa[x],z=fa[y];
if (z!=goal)
{
if (chk(x)==chk(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
if (!goal) root=x;
}
inline void insert(int x)
{
ncnt++;
if (root) ch[root][1]=ncnt;
ch[ncnt][0]=ch[ncnt][1]=0;
fa[ncnt]=root; val[ncnt]=x;
cnt[ncnt]=size[ncnt]=1;
splay(ncnt,0);
}
inline int kth(int x)
{
int cur=root;
while (true)
{
pushdown(cur);
if (x<=size[ch[cur][0]]) cur=ch[cur][0];
else if (x>size[ch[cur][0]]+cnt[cur])
{
x-=size[ch[cur][0]]+cnt[cur];
cur=ch[cur][1];
}
else return cur;
}
}
// Flags the elements at ranks l..r (1-based, inclusive) for reversal.
// The neighbour just left of the range becomes the root and the neighbour just
// right of it becomes the root's right child, so the range is exactly the left
// subtree of that child; its lazy flag is toggled.
// Ranks l-1 and r+1 must exist (main keeps a sentinel at each end for this).
inline void work(int l,int r)
{
    const int leftNeighbour = kth(l - 1);
    const int rightNeighbour = kth(r + 1);
    splay(leftNeighbour, 0);
    splay(rightNeighbour, leftNeighbour);

    const int range = ch[ch[root][1]][0];
    if (range) tag[range] ^= 1;
}
int main(){
scanf("%d%d",&n,&m);
for (register int i=1; i<=n; ++i) scanf("%d",&a[i]),pos[a[i]]=i+1;
insert(-1e8);
for (register int i=1; i<=n; ++i) insert(a[i]);
insert(1e8);
while (m--)
{
scanf("%s%d",str,&x);
if (str[0]=='T')
{
splay(pos[x],0);
l=size[ch[root][0]]+1;
if (l==2) continue;
work(2,l);
work(3,l);
}
if (str[0]=='B')
{
splay(pos[x],0);
l=size[ch[root][0]]+1;
if (l==n+1) continue;
work(l,n+1);
work(l,n);
}
if (str[0]=='I')
{
scanf("%d",&t);
splay(pos[x],0);
l=size[ch[root][0]]+1;
if (t==-1) work(l-1,l);
if (t==0) continue;
if (t==1) work(l,l+1);
}
if (str[0]=='A')
{
splay(pos[x],0);
printf("%d\n",size[ch[root][0]]-1);
}
if (str[0]=='Q')
{
printf("%d\n",val[kth(x+1)]);
}
}
return 0;
}