主席树详细讲解

主席树,其实就是可持久化线段树.

可持久化,意思就是支持查询历史版本,而一般可持久化的方法都是只修改不一样的地方,别的照样用前面的那个数据结构.

据说主席树是某主席在考场上时不会用划分树(好像是忘了怎么写),临时发明出来的.

至于主席树的可持久化到底是什么意思,可以看一下这张图:


我们发现,这些线段树里面(只是没标所表示的区间),很多信息跟上一棵线段树是一样的.

所以我们可以这样存这几棵线段树:


WA,好像这样很省空间的样子,还很省时间.

是的,若是每次修改后多存一颗线段树,那么这样修改顶多改变了一条链,也就是log(n)个节点.

假设修改了m次,那么这样存线段树就可以做到O(m*log(n))的时间和空间处理.

扫描二维码关注公众号,回复: 1713119 查看本文章

可是一般好像用不到那么多线段树.

但是我们看看区间第k小这道题就知道它有多么牛逼了.

题目:http://poj.org/problem?id=2104.

这道题我们先考虑建出权值线段树,也就是叶子结点维护的是权值,整棵线段树维护的是一个权值区间内的数的数量的线段树.

那我们建n棵线段树,其中第i棵线段树是用前i个数字建的权值线段树.

那么我们就可以在选择一个区间[l,r]的第k小的时候,我们就可以调用第r棵线段树在权值区间[L,R]的数的数量减去第l-1棵线段树在权值区间[L,R]的数的数量得到区间[l,r]在权值区间[L,R]的数的数量.

那我们就可以利用前缀和很简单的找到第k小了.

所以我们的方案是可行的,只不过时空复杂度过高.

但是我们上面已经讨论过了这种问题的解决方案,就是调用上一颗线段树的信息.

不过这题的数据的权值过大,我们需要先进行离散化的操作.

其实就是把输入的数据排个序,然后用二分查找映射过去.

离散化过程时间复杂度O(nlog(n)),代码如下:

int find(int k){
  int l=1,r=n,mid;
  while (l+1<r){
    mid=l+r>>1;
    if (order[mid]>=k) r=mid;
    else l=mid;
  }
  if (k==order[l]) return l;
  else return r;
}
inline void work(){
  sort(order+1,order+1+n);
  for (int i=1;i<=n;i++)
    num[i]=find(a[i]);
}

接下来是一棵主席树的每个节点的结构体(大家都不存左右端点,我就要存):

struct tree{
  int sum,ls,rs;
}tr[N*50];
int top=0,root[N+1];

然后是新增一棵线段树:

void add_tree(int num,int last,int L,int R){
  tr[++top].sum=tr[last].sum+1;tr[top].ls=tr[last].ls;tr[top].rs=tr[last].rs;
  if (L==R) return;
  int mid=L+R>>1;
  if (num<=mid) tr[top].ls=top+1,add_tree(num,tr[last].ls,L,mid);
  else tr[top].rs=top+1,add_tree(num,tr[last].rs,mid+1,R);
}

之后是查询第k大:

int query(int lk,int rk,int k,int L,int R){
  if (L==R) return L;
  int mid=L+R>>1,ts=tr[tr[rk].ls].sum-tr[tr[lk].ls].sum;
  if (k<=ts) return query(tr[lk].ls,tr[rk].ls,k,L,mid);
  else return query(tr[lk].rs,tr[rk].rs,k-ts,mid+1,R);
}

那么AC代码如下:

#include<bits/stdc++.h>
  using namespace std;
int read(){
  int x=0;
  char c=getchar();
  for (;c<'0'||c>'9';c=getchar());
  for (;c<='9'&&c>='0';c=getchar()) x=x*10+c-'0';
  return x;
}
const int N=100000;
int a[N+1],order[N+1],num[N+1],n,m;
struct tree{
  int sum,ls,rs;
}tr[N*50];
int top=0,root[N+1];
inline void into(){
  scanf("%d%d",&n,&m);
  for (int i=1;i<=n;i++)
    scanf("%d",&a[i]),order[i]=a[i];
}
int find(int k){
  int l=1,r=n,mid;
  while (l+1<r){
    mid=l+r>>1;
    if (order[mid]>=k) r=mid;
    else l=mid;
  }
  if (k==order[l]) return l;
  else return r;
}
void add_tree(int num,int last,int L,int R){
  tr[++top].sum=tr[last].sum+1;tr[top].ls=tr[last].ls;tr[top].rs=tr[last].rs;
  if (L==R) return;
  int mid=L+R>>1;
  if (num<=mid) tr[top].ls=top+1,add_tree(num,tr[last].ls,L,mid);
  else tr[top].rs=top+1,add_tree(num,tr[last].rs,mid+1,R);
}
int query(int lk,int rk,int k,int L,int R){
  if (L==R) return L;
  int mid=L+R>>1,ts=tr[tr[rk].ls].sum-tr[tr[lk].ls].sum;
  if (k<=ts) return query(tr[lk].ls,tr[rk].ls,k,L,mid);
  else return query(tr[lk].rs,tr[rk].rs,k-ts,mid+1,R);
}
inline void work(){
  sort(order+1,order+1+n);
  for (int i=1;i<=n;i++)
    num[i]=find(a[i]);
  for (int i=1;i<=n;i++)
    root[i]=top+1,add_tree(num[i],root[i-1],1,n);      //这里插入要插入num[i]
}
inline void outo(){
  int l,r,k;
  for (int i=1;i<=m;i++){
    scanf("%d%d%d",&l,&r,&k);
    printf("%d\n",order[query(root[l-1],root[r],k,1,n)]);      //这里必须是输出order[query(...)]
  }
}
int main(){
  into();
  work();
  outo();
  return 0;
}

至于为什么要输出order[query(...)].

是因为你存的是离散化过后的,而对应i的数字应该是order[i],也就是排序过后的第i小的数.

猜你喜欢

转载自blog.csdn.net/hzk_cpp/article/details/80536191