【算法详解 && LCA】最近公共祖先

一、定义

  • LCA(Lowest Common Ancestors),即最近公共祖先,是指在有根树中,找出某两个结点u和v最近的公共祖先。

公共祖先是什么?对于x,y。如果z既是x的祖先也是y的祖先,那么我们就称z是x和y的公共祖先。

在这里插入图片描述

如上图,结点4,6的公共祖先有1、2,
但最近的公共祖先是2,即Lca(4,6) = 2


二、求法

  • 向上标记法

思想:先让u,v中深度大的往上走,直到u,v深度相同,若此时u==v,则已找到。再让u,v一起往上走,直到走到同一个结点
时间复杂度:O(n)

暴力求解法时间复杂度其实也还可以接受,但是如果对于多组样例就不如tarjan了。

那么具体代码如下:
在这里插入图片描述

  • 树上倍增法

I.思想:注意到u,v走到最近公共祖先w之前,u,v所在结点不相同。而到达最近公共祖先w后,再往上走仍是u,v的公共祖先,即u,v走到同一个结点,这具有二分性质。于是可以预处理出一个 2 k 2^k 的表,fa[k][u]表示u往上走 2 k 2^k 步走到的结点,令根结点深度为0,则 2 k 2^k >depth[u]时,令fa[k][u]=-1(不合法情况的处理)

不妨假设depth[u] < depth[v]
①将v往上走d = depth[v] - depth[u]步,此时u,v所在结点深度相同,该过程可用二进制优化。由于d是确定值,将d看成2的次方的和值, d = 2 k 1 + 2 k 2 + . . . + 2 k m d = 2^{k1} + 2^{k2} + ... + 2^{km} ,利用fa数组,如 v = f a [ k 1 ] [ v ] v = fa[k1][v] v = f a [ k 2 ] [ v ] v = fa[k2][v] 进行加速上升
②若此时 u = v u = v ,说明Lca(u,v)已找到
③利用fa数组加速u,v一起往上走到最近公共祖先w的过程。令 d = d e p t h [ u ] d e p t h [ w ] d = depth[u] - depth[w] ,虽然d是个未知值,但依然可以看成2的次方的和。从高位到低位枚举d的二进制位,设最低位为第0位,若枚举到第k位,有 f a [ k ] [ u ] ! = f a [ k ] [ v ] fa[k][u] != fa[k][v] ,则令 u = f a [ k ] [ u ] u = fa[k][u] v = f a [ k ] [ v ] v = fa[k][v] 。最后最近公共祖先 w = f a [ 0 ] [ u ] = f a [ 0 ] [ v ] w = fa[0][u] = fa[0][v] ,即u和v的父亲.


II.那么我们接下来想如何预处理?
解法:
k=0时, f a [ k ] [ u ] fa[k][u] 为u在有根树中的父亲,令根结点 f a [ k ] [ r o o t ] = 1 fa[k][root]=-1
k>0时, f a [ k ] [ u ] = f a [ k 1 ] [ f a [ k 1 ] [ u ] ] fa[k][u]=fa[k-1][fa[k-1][u]] 。树的高度最多为 n n ,k是 l o g ( n ) log(n) 级别。


III.复杂度:
预处理O(nlogn)
单次查询O(logn)

那么具体代码如下:

#include<cstdio>
#include<cstring>
#include<string>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
using namespace std;
int t;
int fa[10001];
int d[30001];
struct node{
    int y,v,Next;
}e[50001];
int Fa[10010][25];
int n,m;
int len=0;
int linkk[10010];
bool vis[20001];
int root;
void insert(int x,int y,int v){
    e[++len].Next=linkk[x];
    linkk[x]=len;
    e[len].v=v;
    e[len].y=y;
}
void dfs(int now,int de){
	if (vis[now]) return;
	vis[now]=1;
    if (d[now]==0&&now!=root) d[now]=de;else d[now]=min(d[now],de);
    for(int i=linkk[now];i;i=e[i].Next){
	    int y=e[i].y;
	    if (y==Fa[now][0]) continue;
	    Fa[y][0]=now;
	    dfs(y,de+1);
	}
}
void find_Fa(){
    for (int j=1;(1<<j)<n;j++)
      for (int i=1;i<=n;i++)
        if (Fa[i][j-1]==-1) Fa[i][j]=-1;
        else Fa[i][j]=Fa[Fa[i][j-1]][j-1];
}
int lca(int u,int v){
    if (d[u]>d[v]) swap(u,v);
    for (int dd=d[v]-d[u],i=0;dd;dd>>=1,i++)
      if (dd&1) v=Fa[v][i];
    if (u==v)return u;
    for (int i=24;i>=0;i--)
	  if (Fa[u][i]!=Fa[v][i]) u=Fa[u][i],v=Fa[v][i];
	 return  Fa[u][0];
}
int main(){
	scanf("%d",&t);
	while (t--){
		int st,ed;
		memset(vis,0,sizeof(vis));
		memset(d,0,sizeof(d));
		len=0;
		memset(linkk,0,sizeof(linkk));
		memset(Fa,0,sizeof(Fa));
		memset(fa,0,sizeof(fa));
		scanf("%d",&n);
	    for (int i=1,x,y;i<n;i++) scanf("%d %d",&x,&y),fa[y]=x,insert(x,y,1),insert(y,x,1);
	    scanf("%d %d",&st,&ed);
		for (int i=1;i<=n;i++) if (!fa[i]){root=i;break;}
		Fa[root][0]=-1;
		dfs(root,0);
		find_Fa();
		printf("%d\n",lca(st,ed));
	}
}

  • 离线tarjan

I.离线与在线的区别:
离线算法就是先把所有询问存起来,一次处理完,最后输出。
而在线算法就是即询问即计算,前面两个算法都是在线算法。

II.思想:
Tarjan算法基于这样一个事实,要找w=Lca(u,v),在dfs遍历完u到遍
历完v的过程中,遍历到v时,u到w路径上除w外结点的子树都遍历
过了,w的子树还未遍历完。如果对于结点u,访问完它的子树后就
把u在并查集中的父亲设为它在树中的父亲,那么访问到v时u在并
查集中的父亲就是Lca(u,v)。

那么具体代码如下:

#include<cstdio>
#include<cstring>
#include<string>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
using namespace std;
#define mp make_pair
typedef pair < int , int > pii;
int root;
int t;
int fa[40010];
vector < pii > e[80020];
vector < pii > id[80020];
bool vis[40010];
int n,m;
int len=0;
int ans[40010];
int a[40010];
int d[40010];
int getfa(int k){
    return k==fa[k]?k:fa[k]=getfa(fa[k]);
}
void tarjan(int u){
    vis[u]=1;
    for (int i=0;i<e[u].size();i++){
	    int y=e[u][i].first;
	    if (vis[y]) continue;
	    tarjan(y);
	    fa[y]=u;
	}
	for (int i=0;i<id[u].size();i++)
	  if (vis[id[u][i].second])
	    ans[id[u][i].first]=d[u]+d[id[u][i].second]-2*d[getfa(id[u][i].second)];
}
void dfs(int u,int de){
	if (vis[u]) return;
	vis[u]=1;
    d[u]=de;
    for (int i=0;i<e[u].size();i++){
	    int y=e[u][i].first;
	    dfs(y,de+e[u][i].second);
	}
}
int main(){
    scanf("%d",&t);
    while (t--){
    	memset(vis,0,sizeof(vis));
    	memset(ans,0,sizeof(ans));
    	memset(d,0,sizeof(d));
    	memset(a,0,sizeof(a));
    	memset(vis,0,sizeof(vis));
	    scanf("%d %d",&n,&m);
	    for (int i=1;i<=n;i++) e[i].clear();
	    for (int i=1;i<=m;i++) id[i].clear();
	    for (int i=1,x,y,z;i<n;i++) scanf("%d %d %d",&x,&y,&z),a[y]=x,e[x].push_back(mp(y,z)),e[y].push_back(mp(x,z));
	    for (int i=1,x,y;i<=m;i++) scanf("%d %d",&x,&y),id[x].push_back(mp(i,y)),id[y].push_back(mp(i,x));
	    for (int i=1;i<=n;i++) if (!a[i]){root=i;break;}
	    dfs(root,0);
	    memset(vis,0,sizeof(vis));
	    for (int i=1;i<=n;i++) fa[i]=i;
	    tarjan(root);
	    for (int i=1;i<=m;i++)
	      printf("%d\n",ans[i]);
	}
	return 0;
}

具体例题请看我的博客题解

猜你喜欢

转载自blog.csdn.net/huang_ke_hai/article/details/87298150