树链剖分基础模板(BZOJ1036[ZJOI2008]树的统计Count)

版权声明:转载什么的无所谓啦,反正注明一下出处就行啦~ https://blog.csdn.net/u013672056/article/details/76220039
摘自XZY的博客

1. 前言

如果给你一棵树,求点u到点v路径上点的权值之和,你可能会说:倍增啊!

那如果出题人:我还要你支持修改某个点的权值!

或者再j一点:我还要你支持修改点u到点v路径上点的权值!

那就得用树链剖分了。

2. 什么是树链剖分

上面那个问题,树上区间修改。

区间修改最常见做法就是线段树了。

那我们怎么用线段树维护一颗。。。普通的树呢?

那就给普通的树的每个节点标个号,然后放线段树里呗。区间维护。

但如果随便标号,那点u到点v路径不一定标号是连续的啊,你线段树维护个j啊

所以我们现在引入一个(堆)姿势:



那么这些姿势有什么用呢?先看一道题吧!

传送门= ̄ω ̄=

我们先来看张图:

图中标在边上了,但也不影响我们学。。。

标号方法是:跑dfs,先给当前节点标号,再给重儿子标号(重儿子和当前节点在一个重链上),然后对重儿子递归,最后给剩下的别的儿子标号(别的儿子不和当前节点在一个重链上,所以新建重链,把新建的重链的顶端节点设为那个”别的儿子“)、递归(图中是先给重(zhong)边标号,再给剩下的边标号)。标号从小到大。

不难发现一条重链上的标号是连续的,比如点1到点14,点2到点12

这意味着在线段树中,它们是在一个连续的区间里的,而不是像随便标号时断断续续的。

这样就很好用线段树处理了。

如果两个节点不在一条重链上呢?比如图中的点11和点10,我们要求它们之间的路径上的点权和

那我们就看,点11所在的重链是11->6->2,点10所在的重链是10(10所在的重链只有一个点,就是10)。所以我们就先求出重链11->6->2上的点权和、重链10上的点权和。这两条重链在线段树上都是一段连续的区间,可以直接log2n求出

这时候我们发现还有4->1的重链没有计算,就把它的点权和计算出来,三个重链的点权和加在一起就得到了答案。

所以我们要记录的是:
1. pos[i] 点i的标号
2. top[i] 点i所在重链的顶端节点
3. siz[i] 以点i为根的子树的大小
4. dep[i] 点i的深度
5. fa[i] 点i的父亲节点

我们先跑一边dfs,算出fa、size、dep

然后再跑一边dfs,根据size[i]找出点i的重儿子,然后算出pos、top。

搞完这些就很easy了,因为一段重链在线段树里是一段连续的区间(这是坠重要的)。

我们在查询/修改从点u到点v的路径时,先找到所在重链的顶端节点(top)深度较深的(因为这样能让u和v同步提升,防止一个提到根节点了,另一个还没提,这时候你就不知道提谁了),注意不能按照u和v的深度来提!比如top较深的点是u,然后就用线段树处理区间[pos[top[u]],pos[u]](因为top[u]的标号一定比u要小),再设置u为fa[top[u]],把u往上提,直至u和v在一条重链上(即top[u]==top[v])。这时候可能u和v之间还有一段距离,此时u和v已经在一条重链上,直接处理它们之间的区间就行了。

然后复杂度就是:O(Nlog_{2} N+Qlog^{2} N)

同时这个复杂度也是一般的树链剖分的复杂度。因为重链个数不会超过log_{2} N个,线段树复杂度是log_{2} N的。网上有证明,我就不做过多赘述

然后不要脸的放上我的代码。。

/*
siz[]数组,用来保存以x为根的子树节点个数
top[]数组,用来保存当前节点的所在链的顶端节点
son[]数组,用来保存重儿子
dep[]数组,用来保存当前节点的深度
fa[]数组,用来保存当前节点的父亲
pos[]数组,用来保存树中每个节点剖分后的新编号
rank[]数组,用来保存当前节点在线段树中的位置
*/
#pragma GCC optimize("O2")
#include<bits/stdc++.h>
#define maxn 50000
#define ls (rt<<1)
#define rs (rt<<1|1)
using namespace std;
 
//init begin
struct TREE
{
    int l,r,sum,max;
};
TREE tr[maxn<<2];
char opt[200];
int pos[maxn],fa[maxn],sz[maxn],w[maxn],deep[maxn],q,n,m,cnt=0,top[maxn],son[maxn],rank[maxn];
vector<int> g[maxn];
//init end
 
//IntervalTree begin
 
void pushup(int rt)
{
    tr[rt].sum=tr[ls].sum+tr[rs].sum;
    tr[rt].max=max(tr[ls].max,tr[rs].max);
    return ;
}
 
void build(int l,int r,int rt)
{
    tr[rt].l=l,tr[rt].r=r;
    if(l==r)
        {tr[rt].sum=tr[rt].max=w[l];return ;}
    int mid=l+(r-l)/2;
    build(l,mid,ls),build(mid+1,r,rs);
    pushup(rt);
}
 
void update(int l,int c,int rt)
{
    if(tr[rt].l==tr[rt].r)
    {
        tr[rt].sum=tr[rt].max=c;
        return ;
    }
    int mid=tr[rt].l+(tr[rt].r-tr[rt].l)/2;
    if(l<=mid) update(l,c,ls);
    else update(l,c,rs);
    pushup(rt);
}
 
int query_max(int l,int r,int rt)
{
    if(l<=tr[rt].l&&tr[rt].r<=r) return tr[rt].max;
    int mid=tr[rt].l+(tr[rt].r-tr[rt].l)/2,ans=INT_MIN;
    if(l<=mid) ans=max(ans,query_max(l,r,ls));
    if(r>mid) ans=max(ans,query_max(l,r,rs));
    return ans;
}
 
int query_sum(int l,int r,int rt)
{
    if(l<=tr[rt].l&&tr[rt].r<=r) return tr[rt].sum;
    int mid=tr[rt].l+(tr[rt].r-tr[rt].l)/2,ans=0;
    if(l<=mid) ans+=query_sum(l,r,ls);
    if(r>mid) ans+=query_sum(l,r,rs);
    return ans;
}
 
//IntervalTree end
 
void dfs1(int x,int fat,int d)
{
    sz[x]=1;deep[x]=d;fa[x]=fat;
    for(int i=0;i<g[x].size();i++)
        if(fa[x]!=g[x][i])
        {
            dfs1(g[x][i],x,d+1);
			sz[x]+=sz[g[x][i]];
			if(son[x]==-1||sz[g[x][i]]>sz[son[x]])
				son[x]=g[x][i];
        }
}
 
void dfs2(int x,int tp)
{
    top[x]=tp,pos[x]=++cnt,rank[pos[x]]=x;
    if(son[x]==-1) return ;
    dfs2(son[x],tp);
    for(int i=0;i<g[x].size();i++)
        if(g[x][i]!=fa[x]&&g[x][i]!=son[x])
            dfs2(g[x][i],g[x][i]);
}
 
 
int lca(int a,int b,int ok)
{
    int ans=ok?INT_MIN:0;
    while(top[a]!=top[b])
    {
        if(deep[top[a]]>deep[top[b]]) swap(a,b);
        if(!ok) ans+=query_sum(pos[top[b]],pos[b],1);
        else ans=max(ans,query_max(pos[top[b]],pos[b],1));
        b=fa[top[b]];
    }
    if(deep[a]>deep[b]) swap(a,b);
    if(!ok) ans+=query_sum(pos[a],pos[b],1);
    else ans=max(ans,query_max(pos[a],pos[b],1));
    return ans;
}
 
int main()
{  
    int u,v;
    memset(son,-1,sizeof(son));
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(1,-1,1),top[1]=1,dfs2(1,1);
    for(int i=1;i<=n;i++)
        scanf("%d",&w[pos[i]]);
    build(1,cnt,1);
    scanf("%d",&q);
    for(int i=1;i<=q;i++)
    {
        scanf("%s%d%d",opt,&u,&v);
        if(opt[1]=='M')printf("%d\n",lca(u,v,1));
        else if(opt[1]=='H') update(pos[u],v,1);
        else if(opt[1]=='S') printf("%d\n",lca(u,v,0));
    }
    return 0;
}

  

  

猜你喜欢

转载自blog.csdn.net/u013672056/article/details/76220039
今日推荐