二逼平衡树(树套树)

传送门

这道题的做法……我学的是最经典的线段树套平衡树。

因为发现其实这题的题目描述和普通平衡树非常的相似……只是这次是在给定的区间中。所以我们能想象到用线段树维护区间,然后每个线段树的节点都是一颗平衡树,用于维护区间内信息。

具体操作的实现办法:
1.查询k在区间内的排名:在给定的区间的每一个平衡树上求k的排名,其和即为答案。
2.查询区间内排名为k的数:这个操作是不能在线段树上叠加的,所以我们需要二分答案转化为判定类问题,就是转化为问题一。
3.修改某一位置上数值:在给定区间内的所有平衡树上找到这个数并且修改。
4.查询k在区间内前驱:在给定区间所有平衡树内查k的前驱,取最大值。
5.查询k在区间内后继:在给定区间所有平衡树内查k后继,取最小值。

以上操作除了操作2需要二分答案,复杂度是\(O(log^3n)\),剩下的都是\(O(log^2n)\)的。

然后这个具体的实现方法很复杂……其实平衡树内部和线段树内部的操作和普通的方法基本都是大同小异的。不同的在于插入,删除节点以及新建节点。具体的思路其实和平衡树也很像……不过转移到线段树上比较复杂,但是看看代码都能看懂。

注意这题要垃圾回收。还有就是树套树真的好长……还挺容易写错的……这玩意也是咋写都要5k+……

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstring>
#define rep(i,a,n) for(register int i = a;i <= n;i++)
#define per(i,n,a) for(register int i = n;i >= a;i--)
#define enter putchar('\n')
#define pr pair<int,int>
#define mp make_pair
#define fi first
#define sc second
#define I inline
#define get(x) (t[t[x].fa].ch[1] == (x))
using namespace std;
typedef long long ll;
const int M = 1000005;
const int N = 10000005;
const int INF = 2147483647;

int read()
{
   int ans = 0,op = 1;char ch = getchar();
   while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
   while(ch >='0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
   return ans * op;
}

struct tree
{
   int ch[2],val,size,cnt,fa;
}t[M];

int n,m,a[M],root[M],idx,bin[M],btop,op,l,r,x,y,z;

I int newnode(int x)
{
   int u = btop ? bin[btop--] : ++idx;
   t[u].val = x,t[u].cnt = t[u].size = 1,t[u].fa = t[u].ch[0] = t[u].ch[1] = 0;
   return u;
}

I void update(int u)
{
   t[u].size = t[t[u].ch[0]].size + t[t[u].ch[1]].size + t[u].cnt;
}

I void rotate(int x)
{
   int y = t[x].fa,z = t[y].fa,k = get(x);
   if(z) t[z].ch[get(y)] = x;
   t[x].fa = z,t[y].ch[k] = t[x].ch[k^1],t[t[y].ch[k]].fa = y;
   t[x].ch[k^1] = y,t[y].fa = x;
   update(y),update(x);
}

I void splay(int x)
{
   while(t[x].fa)
   {
      int y = t[x].fa,z = t[y].fa;
      if(z) ((t[y].ch[0] == x) ^ (t[z].ch[0] == y)) ? rotate(x) : rotate(y);
      rotate(x);
   }
   update(x);
}

I int getnum(int k,int x)
{
   int u = root[k],v = 0;
   while(u)
   {
      v = u;
      if(t[u].val == x) return u;
      if(x < t[u].val) u = t[u].ch[0];
      else u = t[u].ch[1];
   }
   return v;
}

I int getkth(int k,int x)
{
   int u = root[k];
   while(u)
   {
      if(x <= t[t[u].ch[0]].size) u = t[u].ch[0];
      else if(x > t[t[u].ch[0]].size + t[u].cnt) x -= (t[t[u].ch[0]].size + t[u].cnt),u = t[u].ch[1];
      else return u;
   }
   return 0;
}

I int getless(int k,int x)
{
   int u = root[k],cur = 0;
   while(u)
   {
      if(t[u].val < x) cur += t[u].cnt + t[t[u].ch[0]].size,u = t[u].ch[1];
      else u = t[u].ch[0];
   }
   return cur;
}

I int getmax(int u)
{
   while(t[u].ch[1]) u = t[u].ch[1];
   return u;           
}

I int getmin(int u)
{
   while(t[u].ch[0]) u = t[u].ch[0];
   return u;
}

I int getpre(int k,int x)
{
   int u = getnum(k,x);
   if(!u) return -INF;
   splay(u),root[k] = u;
   if(t[u].val >= x) u = getmax(t[u].ch[0]);
   return u ? t[u].val : -INF;
}

I int getnext(int k,int x)
{
   int u = getnum(k,x);
   if(!u) return INF;
   splay(u),root[k] = u;
   if(t[u].val <= x) u = getmin(t[u].ch[1]);
   return u ? t[u].val : INF;
}

I void insert(int k,int x)
{
   int u = getnum(k,x);
   if(t[u].val == x)
   {
      splay(u),root[k] = u;
      t[u].cnt++,t[u].size++;
      return;
   }
   u = newnode(x);
   if(!root[k]) {root[k] = u;return;}
   int v = root[k],w = 0,dir = 0;
   while(v)
   {
      w = v;
      if(t[u].val <= t[v].val) dir = 1,v = t[v].ch[0];
      else dir = 0,v = t[v].ch[1];
   }
   if(dir) t[w].ch[0] = u;
   else t[w].ch[1] = u;
   t[u].fa = w,splay(u),root[k] = u;
}

I void del(int k,int x)
{
   int u = getnum(k,x);
   splay(u),root[k] = u;
   if(t[u].cnt > 1) {t[u].cnt--,t[u].size--;return;}
   if(t[u].size == 1) root[k] = 0;
   else if(!t[u].ch[0] || !t[u].ch[1])
   {
      root[k] = t[u].ch[0] | t[u].ch[1];
      t[root[k]].fa = 0;
   }
   else
   {
      t[t[u].ch[0]].fa = 0;
      int v = getmax(t[u].ch[0]);
      splay(v),root[k] = v;
      t[v].ch[1] = t[u].ch[1],t[t[u].ch[1]].fa = v,update(v);
   }
   bin[++btop] = u;
}

void segbuild(int p,int l,int r)
{
   rep(i,l,r) insert(p,a[i]);
   if(l == r) return;
   int mid = (l+r) >> 1;
   segbuild(p<<1,l,mid),segbuild(p<<1|1,mid+1,r);
}

void segchange(int p,int l,int r,int k,int val)
{
   del(p,a[k]),insert(p,val);
   if(l == r) return;
   int mid = (l+r) >> 1;
   if(k <= mid) segchange(p<<1,l,mid,k,val);
   else segchange(p<<1|1,mid+1,r,k,val);
}

int segless(int p,int l,int r,int kl,int kr,int val)
{
   if(l == kl && r == kr) return getless(p,val);
   int mid = (l+r) >> 1;
   if(kr <= mid) return segless(p<<1,l,mid,kl,kr,val);
   else if(kl > mid) return segless(p<<1|1,mid+1,r,kl,kr,val);
   else return segless(p<<1,l,mid,kl,mid,val) + segless(p<<1|1,mid+1,r,mid+1,kr,val);
}

int segpre(int p,int l,int r,int kl,int kr,int val)
{
   if(l == kl && r == kr) return getpre(p,val);
   int mid = (l+r) >> 1;
   if(kr <= mid) return segpre(p<<1,l,mid,kl,kr,val);
   else if(kl > mid) return segpre(p<<1|1,mid+1,r,kl,kr,val);
   else return max(segpre(p<<1,l,mid,kl,mid,val),segpre(p<<1|1,mid+1,r,mid+1,kr,val));
}

int segnext(int p,int l,int r,int kl,int kr,int val)
{
   if(l == kl && r == kr) return getnext(p,val);
   int mid = (l+r) >> 1;
   if(kr <= mid) return segnext(p<<1,l,mid,kl,kr,val);
   else if(kl > mid) return segnext(p<<1|1,mid+1,r,kl,kr,val);
   else return min(segnext(p<<1,l,mid,kl,mid,val),segnext(p<<1|1,mid+1,r,mid+1,kr,val));
}

int segkth(int kl,int kr,int k)
{
   int L = 0,R = 100000000;
   while(L < R)
   {
      int mid = (L+R+1) >> 1;
      if(segless(1,1,n,kl,kr,mid) > k-1) R = mid - 1;
      else L = mid;
   }
   return L;
}

int main()
{
   n = read(),m = read(),t[0].val = -1;
   rep(i,1,n) a[i] = read();
   segbuild(1,1,n);
   while(m--)
   {
      op = read();
      if(op == 3) x = read(),y = read(),segchange(1,1,n,x,y),a[x] = y;
      else l = read(),r = read(),x = read();
      if(op == 1) printf("%d\n",segless(1,1,n,l,r,x) + 1);
      if(op == 2) printf("%d\n",segkth(l,r,x));
      if(op == 4) printf("%d\n",segpre(1,1,n,l,r,x));
      if(op == 5) printf("%d\n",segnext(1,1,n,l,r,x));
   }
   return 0;
}

猜你喜欢

转载自www.cnblogs.com/captain1/p/10193468.html
今日推荐