倍增 || ST-RMQ 求LCA

视频讲解戳这里(bj聚聚讲的非常好,一听就懂)

由luogu一道模板题讲起 传送门 (都9102年,这题还卡输入输出,要吸氧才能过)

LCA:求两个点的最近公共祖先

倍增求LCA:

有很多博客图文并茂讲的非常清楚,倍增LCA的主要思想就是,先将u,v(保证u的深度更深)两点调整到同一水平线

情况一:u==v,那就说明最近公共祖先是u,这种情况说明u,v在同一边,祖先自然是深度较浅的那一个    

情况而:u,v每次往上跳2的j次方(j从小到大),如果二者所跳到的点相同,那就不跳,反之则跳,最后就落在了最近公共祖先的下一个位置

代码结合视频看起来更香噢:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<string>
#include<cstring>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<cmath>
using namespace std;
typedef long long ll;
const ll inf=0x3f3f3f3f;
const int maxn=5e5+5;
const int maxbit=20;
int n,m,s;
int dep[maxn],fa[maxn][maxbit],lg[maxn];
//记录每个点的深度;fa[i][j],i点往上跳2^J的父亲节点;预处理log2,向下取整
vector<int> G[maxn];
void dfs(int np,int fat){ //当前节点,父亲节点
	dep[np]=dep[fat]+1;
	fa[np][0]=fat;
	for(int j=1;j<=lg[dep[np]]+1;++j){
		fa[np][j]=fa[fa[np][j-1]][j-1]; 
//np往上跳2^j的节点相当于先往上跳到2^(j-1)节点处再往上跳2^(j-1)
	}
	for(int i=0;i<(int)G[np].size();++i){
		if(G[np][i]!=fat){
			dfs(G[np][i],np);
		}
	}
}
int lca(int u,int v){
	if(dep[u]<dep[v])	swap(u,v);
	while(dep[u]!=dep[v]){ //调整平衡
		u=fa[u][lg[dep[u]-dep[v]]];
	}
	if(u==v)	return u; //情况一
	for(int i=lg[dep[u]];i>=0;--i){ //情况二
		if(fa[u][i]!=fa[v][i]){
			u=fa[u][i];
			v=fa[v][i];
		}
	}
	return fa[u][0];
}
int main(){
	scanf("%d%d%d",&n,&m,&s);
	int x,y;
	lg[0]=-1;
	for(int i=1;i<maxn;++i){
		lg[i]=lg[i>>1]+1;
	}
	for(int i=1;i<=n-1;++i){
		scanf("%d%d",&x,&y);
		G[x].push_back(y);
		G[y].push_back(x); 
	}
	dfs(s,0);
	for(int i=1;i<=m;++i){
		scanf("%d%d",&x,&y);
		printf("%d\n",lca(x,y));
	}
	return 0;
}

纯净版:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<string>
#include<cstring>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<cmath>
using namespace std;
typedef long long ll;
const ll inf=0x3f3f3f3f;
const int maxn=5e5+5;
const int maxbit=20;
int n,m,s;
int dep[maxn],fa[maxn][maxbit],lg[maxn];
vector<int> G[maxn];
void dfs(int np,int fat){
	dep[np]=dep[fat]+1;
	fa[np][0]=fat;
	for(int j=1;j<=lg[dep[np]]+1;++j){
		fa[np][j]=fa[fa[np][j-1]][j-1];
	}
	for(int i=0;i<(int)G[np].size();++i){
		if(G[np][i]!=fat){
			dfs(G[np][i],np);
		}
	}
}
int lca(int u,int v){
	if(dep[u]<dep[v])	swap(u,v);
	while(dep[u]!=dep[v]){
		u=fa[u][lg[dep[u]-dep[v]]];
	}
	if(u==v)	return u;
	for(int i=lg[dep[u]];i>=0;--i){
		if(fa[u][i]!=fa[v][i]){
			u=fa[u][i];
			v=fa[v][i];
		}
	}
	return fa[u][0];
}
int main(){
	scanf("%d%d%d",&n,&m,&s);
	int x,y;
	lg[0]=-1;
	for(int i=1;i<maxn;++i){
		lg[i]=lg[i>>1]+1;
	}
	for(int i=1;i<=n-1;++i){
		scanf("%d%d",&x,&y);
		G[x].push_back(y);
		G[y].push_back(x); 
	}
	dfs(s,0);
	for(int i=1;i<=m;++i){
		scanf("%d%d",&x,&y);
		printf("%d\n",lca(x,y));
	}
	return 0;
}

 ST-RMQ 求LCA:

先求一个dfs序列,那求u,v的lca就是u,v在dfs序中最早出现的位置之间深度最小的那一个,st[i][j]代表i~i+2^j区间中深度最小在dfs序中的下表。

代码结合视频看起来更香噢:

扫描二维码关注公众号,回复: 5094126 查看本文章
#include<bits/stdc++.h>
using namespace std;
const int maxn=5e5+5;
const int maxbit=20;
vector<int> G[maxn];
int order[maxn<<2],depth[maxn<<2];
//深搜序列    //深度 
int lg[maxn<<2],st[maxn<<2][maxbit];
//预处理log   //st[i][j]代表i~i+2^j区间中深度最小编号 
int first_place[maxn];//dfs序中i最早出现的下标 
int n,m,s,cnt=0; 
inline int read(){
	char ch=getchar();
	int x=0,f=1;
	while((ch>'9'||ch<'0')&&ch!='-'){
		ch=getchar();
	}
	if(ch=='-'){
		f=-1;
		ch=getchar();
	}
	while('0'<=ch&&ch<='9'){
		x=x*10+ch-'0';
		ch=getchar();
	}
	return x*f;
}
void dfs(int np,int dep){
	++cnt;
	first_place[np]=cnt;
	order[cnt]=np;
	depth[cnt]=dep+1;
	for(int i=0;i<(int)G[np].size();++i){
		int to=G[np][i];
		if(first_place[to]==0){
			dfs(to,dep+1);
			++cnt;
			order[cnt]=np;
			depth[cnt]=dep+1;
		}
	}
}
void STinit(){
	for(int i=1;i<=cnt;++i){
		st[i][0]=i;
	}
	int a,b;
	for(int j=1;j<=lg[cnt];++j){
		for(int i=1;i+(1<<j)-1<=cnt;++i){
			a=st[i][j-1];
			b=st[i+(1<<(j-1))][j-1];
			if(depth[a]<depth[b])
				st[i][j]=a;
			else	st[i][j]=b;
		}
	}
}
int main(){
	n=read(),m=read(),s=read();
	lg[0]=-1;
	for(int i=1;i<maxn*2;++i)	lg[i]=lg[i>>1]+1;
	int x,y;
	for(int i=1;i<n;++i){
		x=read(),y=read();
		G[x].push_back(y);
		G[y].push_back(x);
	}
	dfs(s,0);
	STinit();
	for(int i=1;i<=m;++i){
		x=read(),y=read();
		x=first_place[x];
		y=first_place[y];
		if(x>y)	swap(x,y);
		int k=lg[y-x];
		int a=st[x][k],b=st[y-(1<<k)+1][k];
		if(depth[a]<depth[b])
			printf("%d\n",order[a]);
		else	printf("%d\n",order[b]);
		
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/TDD_Master/article/details/86625725
今日推荐