初学主席树,主要是反复看了卿学姐的视频(我竟然在B站学算法)和知乎“主席树是如何求区间k大的”,才算懂了点皮毛。
传送门:
首先,学习主席树要点的前置技能是权值线段树(卿学姐说的是线段树,个人认为不太确切)。权值线段树之所以会带上“权值”二字,是因为它是记录权值的线段树。因此需要用到离散化操作来处理a[1-n]。记录权值指的是,每个点上存的是区间内的数字出现的总次数。比如一个长度为10的数组[1,1,2,3,3,4,4,4,4,5,5]。
其中1出现了两次,那么[1,1]这个节点的值为2,2出现了1次,那么[2,2]这个节点的值为1,那么显然[1,2]这个节点的值为3,即1出现的次数和2出现的次数加和。那么如果我想要知道这个数组上的第k小,我就可以在这棵权值线段树上用logn的时间来实现。比如我想要求这个区间上的第7小,那么我先找到这棵树的根节点,根节点上的数字显示的是10,表示在[1,8]这个区间上一共有10个数字,那么我只要去看它的左孩子上的个数是多少。这时我看到左孩子上的数字是9,说明前9小的数字都在左子树上,那么我要找的第7小也在左子树上,那么我就递归去找左子树。当我再看左孩子的时候,看到数字是3,说明前3小的数字在左子树上,那么我要找的就是右子树上的第k-sum[i]小,即7-3=4,找到右子树上的第4小即可。直到找到某一个叶子节点,说明找到了我要找的第k小。这是通过权值线段树找到区间[1,n]上的第k小/大的应用。
那么知道了权值线段树是什么之后,主席树又是什么呢。主席树是一棵可持久化线段树,可持久化指的是它保存了这棵树的所有历史版本,最简单的办法是:如果你输入了n个数,那么每输入一个数字a[i],就构造一棵保存了从a[1]到a[i]的权值线段树。之所以这么做,是因为我们可以把第j棵树和第(i-1)棵树上的每个点的权值相减,来得到一颗新的权值线段树,而这个新的权值线段树相当于是输入了a[i]到a[j]以后得到的。如果这么说不太好理解的话,我们可以思考另外一个模型:求数组a[1]到a[n]的和。如果只是求[1,n]这一段的和,那么我们直接全部加起来就可以了,或者求一个前缀和sum[n]即可。那么如果我给定了l和r,想要知道[l,r]这段区间上的和呢?是不是利用前缀和sum[r]-sum[l-1]就可以轻松得到?那么主席树的思想也是如此,将tree[r]-tree[l-1]得到的一棵权值线段树即为属于[l,r]的一棵权值线段树,那么在这么一棵权值线段树上求第k大不是就转变为之前的问题了么。如果还是没有理解为什么可以用tree[r]-tree[l-1]来表示属于[l,r]的权值线段树,可以自己构造一个数组,然后画出属于[1,l-1],[1,r]和[l,r]的三颗权值线段树,来自己研究研究,多自己动手也不是一件坏事嘛。
还有一个问题需要解决,那就是空间问题。显而易见的是,如果每输入一个数就重新构造一棵权值线段树,必然会导致空间不够用:一棵线段树的空间就是n*4,那么一共的空间开销就是n*n*4,显然是会MLE的。那么这个问题怎么解决呢?可以发现每更新一个点,就会从它开始把它的所有祖先都更新一次,而其他的点都没有被改变,即:每次改变的结点只有logn个。这样,我们每次输入一个数,只需要多开logn个空间,那么实际的空间开销只有n*(4+logn),满足了空间要求。
以两道基础题结束。(代码主要仿照卿学姐的视频中的代码)
AC代码:
/********************************************** Author: StupidTurtle Date: 2018.5.25 Email: [email protected] **********************************************/ #include <cstdio> #include <vector> #include <cstring> #include <iostream> #include <algorithm> using namespace std; typedef long long ll ; const int oo = 0x7f7f7f7f ; const int maxn = 1e5 + 7 ; const int mod = 1e9 + 7 ; int n , m , cnt , root[maxn] , a[maxn] , x , y , k ; struct node { int l , r , sum ; }T[maxn*25]; vector<int> v ; int getid ( int x ){ return lower_bound(v.begin(),v.end(),x)-v.begin()+1 ; } void update ( int l , int r , int &x , int y , int pos ){ T[++cnt]=T[y] , T[cnt].sum ++ , x = cnt ; if ( l == r ) return ; int mid = ( l + r ) / 2 ; if ( mid >= pos ) update ( l , mid , T[x].l , T[y].l , pos ); else update ( mid + 1 , r , T[x].r , T[y].r , pos ); } int query ( int l , int r , int x , int y , int k ){ if ( l == r ) return l ; int mid = ( l + r ) / 2 ; int sum = T[T[y].l].sum - T[T[x].l].sum ; if ( sum >= k ) return query ( l , mid , T[x].l , T[y].l , k ); else return query ( mid + 1 , r , T[x].r , T[y].r , k - sum ); } int main(void){ scanf("%d%d",&n,&m ); for ( int i = 1 ; i <= n ; i ++ ) scanf("%d",&a[i] ),v.push_back(a[i]); sort(v.begin(),v.end()); v.erase(unique(v.begin(),v.end()),v.end()); for ( int i = 1 ; i <= n ; i ++ ) update ( 1 , n , root[i] , root[i-1] , getid(a[i]) ); for ( int i = 1 ; i <= m ; i ++ ){ scanf("%d%d%d",&x,&y,&k ); printf("%d\n",v[query(1,n,root[x-1],root[y],k)-1] ); } return 0 ; }
AC代码:
/********************************************** Author: StupidTurtle Date: 2018.5.25 Email: [email protected] **********************************************/ #include <cstdio> #include <vector> #include <cstring> #include <iostream> #include <algorithm> using namespace std; typedef long long ll ; const int oo = 0x7f7f7f7f ; const int maxn = 1e5 + 7 ; const int mod = 1e9 + 7 ; int t , n , m , cnt , root[maxn] , a[maxn] , x , y , k ; struct node { int l , r , sum ; }T[maxn*25]; vector<int> v ; int getid ( int x ){ return lower_bound(v.begin(),v.end(),x)-v.begin()+1 ; } void update ( int l , int r , int &x , int y , int pos ){ T[++cnt]=T[y] , T[cnt].sum ++ , x = cnt ; if ( l == r ) return ; int mid = ( l + r ) / 2 ; if ( mid >= pos ) update ( l , mid , T[x].l , T[y].l , pos ); else update ( mid + 1 , r , T[x].r , T[y].r , pos ); } int query ( int l , int r , int x , int y , int k ){ if ( l == r ) return l ; int mid = ( l + r ) / 2 ; int sum = T[T[y].l].sum - T[T[x].l].sum ; if ( sum >= k ) return query ( l , mid , T[x].l , T[y].l , k ); else return query ( mid + 1 , r , T[x].r , T[y].r , k - sum ); } int main(void){ scanf("%d",&t ); while ( t -- ){ v.clear(); cnt = 0 ; scanf("%d%d",&n,&m ); for ( int i = 1 ; i <= n ; i ++ ) scanf("%d",&a[i] ),v.push_back(a[i]); sort(v.begin(),v.end()); v.erase(unique(v.begin(),v.end()),v.end()); for ( int i = 1 ; i <= n ; i ++ ) update ( 1 , n , root[i] , root[i-1] , getid(a[i]) ); for ( int i = 1 ; i <= m ; i ++ ){ scanf("%d%d%d",&x,&y,&k ); printf("%d\n",v[query(1,n,root[x-1],root[y],k)-1] ); } } return 0 ; }
之所以把hdu-2665也放上来,是因为在这道题的初始化上踩了坑,只记得把vector清空,忘记把作为存储初始值的cnt给赋为0,也算是一个提醒吧。顺带一提,hdu-2665的题意说的是求区间第k大,但是区间第k大wa了,而区间第k小AC了,很奇怪2333。