LCA-倍增算法模板+树上前缀和例题

LCA-倍增模板

$O(n\log n)$ 预处理，$O(\log n)$ 查询

// LCA via binary lifting — template.
// Globals expected from the host program:
//   h[] / e[] : adjacency list (head array + edge array with fields to, ne[, w]); -1 ends a list
//   dep[]     : depth; assumes dep[0] == 0 so dfs(root, 0) gives the root depth 1
//   fa[x][i]  : 2^i-th ancestor of x (0 once it runs past the root)
//   lg[x]     : floor(log2(x)), precomputed for every x up to n
//   st[]      : visited flags
// dfs(cur, fath) fills dep[] and fa[][] for every node reachable from cur, whose parent is fath.
void dfs(int cur,int fath)
{
	if(st[cur]) return ;
	st[cur] = 1;
	dep[cur] = dep[fath]+1;
	fa[cur][0] = fath;
	// The 2^i-th ancestor is the 2^(i-1)-th ancestor of the 2^(i-1)-th ancestor;
	// levels above lg[depth] would point past the root, so they are not needed.
	for(int i = 1; i <= lg[dep[cur]];i++)
		fa[cur][i] = fa[fa[cur][i-1]][i-1];
	for(int i = h[cur];i != -1; i = e[i].ne)
	{
		int v = e[i].to;
		if(v == fath) continue;   // do not walk back up the edge we came from
		/* Tree prefix sums over edge weights: uncomment to also maintain
		   dis[v] = total edge weight on the root -> v path.
		dis[v] = dis[cur] + e[i].w;
		*/
		dfs(v,cur);   // cur (not fath) is the parent of v
	}
}

// Lowest common ancestor of a and b, using the dep[] / fa[][] / lg[] tables built by dfs().
int lca(int a,int b)
{
	// make b the deeper of the two nodes
	if(dep[b] < dep[a]) swap(a,b);
	// lift b by the largest power of two not exceeding the depth gap until both are level
	for(int gap; (gap = dep[b]-dep[a]) != 0; )
		b = fa[b][lg[gap]];
	if(b == a) return a;
	// jump both nodes as high as possible while they stay strictly below the LCA
	int k = lg[dep[a]];
	while(k >= 0)
	{
		if(fa[b][k] != fa[a][k])
		{
			a = fa[a][k];
			b = fa[b][k];
		}
		k--;
	}
	return fa[a][0];
}


// Template driver (partial): only the lg[] precomputation is shown; the rest is elided as "...".
int main()
{
    
    
	// lg[i] = floor(log2(i)) for 2 <= i <= n; lg[1] is assumed to stay 0 (zero-initialised global).
	// dfs() and lca() use lg[] to choose binary-lifting levels.
	for(int i = 2; i <= n; i++)
		lg[i] = lg[i>>1]+1;
	// NOTE(review): the elided part presumably reads the tree, calls dfs(root, 0) and answers queries with lca().
	...	
}

树上前缀和

$sum_i$ 表示结点 $i$ 到根节点的权值总和。
然后:

  • 若是点权，$x,y$ 路径上的和为 $sum_x+sum_y-sum_{lca}-sum_{fa_{lca}}$（若 $lca$ 为根，则取 $sum_{fa_{lca}}=0$）
  • 若是边权，$x,y$ 路径上的和为 $sum_x+sum_y-2\,sum_{lca}$
    $lca$ 的求法参见 最近公共祖先。

例题

loj - 0134- Dis 模板题

ACcode:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
#include<stack>
#include<string>
#include<utility>
#include<cmath>
#include<vector>
#include<functional>//使用 greater<int>();
using namespace std;
// ---- aliases and limits ----
using ll = long long;
using pll = pair<int,int>;	// despite the name, a pair of ints (unused in this program)
const int INF = 0x3f3f3f3f;
const int N = 1e4+100;	// maximum node count plus slack

// ---- problem sizes ----
int n,m;
// ---- adjacency list: h[u] = index of u's first edge in e[] (-1 = none), cnt = edges stored ----
int h[N];
int cnt = 0;
// dis[x] = sum of edge weights on the root -> x path (filled by dfs)
int dis[N];
// ---- binary-lifting tables ----
int dep[N];        // depth, root = 1
int fa[N][25];     // fa[x][i] = 2^i-th ancestor of x
int st[N];         // visited flags for dfs
int lg[N];         // lg[x] = floor(log2(x))
// one directed edge: target, weight, next edge leaving the same source
struct node
{
	int to,w,ne;
}e[N<<1];
// One-time setup: lg[x] = floor(log2(x)) for x up to 10010, and every adjacency list empty (-1).
void init(){
	for(int x = 2; x <= 1e4+10; ++x)
		lg[x] = 1 + lg[x >> 1];
	memset(h, -1, sizeof h);
}
void add(int u,int v,int w)
{
    
    
	e[cnt].to = v;
	e[cnt].w = w;
	e[cnt].ne = h[u];
	h[u] = cnt++;
}

void dfs(int cur,int fath)
{
    
    
	if(st[cur]) return ;
	st[cur] = 1;
	dep[cur] = dep[fath] + 1;
	fa[cur][0] = fath;
	for(int i = 1; i <= lg[dep[cur]];i++)
		fa[cur][i] = fa[fa[cur][i-1]][i-1];
	
	for(int i = h[cur];i != -1;i=e[i].ne)
	{
    
    
		int v = e[i].to;
		if(v != fath)
		{
    
    
			dis[v] = dis[cur] + e[i].w;
			dfs(v,cur);
		}
	}
			
	
}

int LCA(int a,int b)
{
    
    
	if(dep[a] > dep[b]) swap(a,b);
	while(dep[a] != dep[b])
		b = fa[b][lg[dep[b]-dep[a]]];
	if(a==b) return a;
	for(int k = lg[dep[a]];k >= 0;k --)
		if(fa[a][k] != fa[b][k])
			a = fa[a][k],b = fa[b][k];
	return fa[a][0];			
}
int main()
{
    
    
	init();
	cin >> n >> m;
	
	int u,v,w;
	int num = n-1;
	while(num--)
	{
    
    
		cin >> u >> v >> w;
		add(u,v,w);
		add(v,u,w);
	}
	dfs(1,0);
	int s,t;
	while(m--)
	{
    
    
		cin >> s >> t;
		int lca = LCA(s,t);
		int res = dis[s]+dis[t] - 2*dis[lca];
		cout << res << endl;
	}
	
	return 0;
}

#2491. 「BJOI2018」求和



/*
	author:nttttt;
	add oil!
*/

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
#include<stack>
#include<string>
#include<utility>
#include<cmath>
#include<vector>
#include<functional>//使用 greater<int>();
using namespace std;
typedef long long ll;
typedef pair<int,int> pll;	// name kept for compatibility; holds two ints (unused here)
const int INF = 0x3f3f3f3f;
const int N = 3e5+100;	// maximum node count plus slack
int n,m,k;
// adjacency list: h[u] = index of u's first edge in e[] (-1 = none), cnt = edges stored
int	h[N],cnt = 0;
// Depths (<= n), ancestor ids and visited flags all fit in int. The ancestor table needs
// levels 0..18 because lg[3e5] == 18, so 20 columns suffice. Together this saves about
// 50 MB compared with long long [N][30].
int dep[N],fa[N][20],st[N];
// ans[x][j]: path prefix sums for exponent j (reduced mod `mod`), filled by dfs.
// node[]: scratch array; node[0] is set to 1 by init.
ll ans[N][51],node[N];
int lg[N];
const int mod = 998244353;	// never modified, so const
struct Node
{
	int to,ne;
}e[N<<1];
// One-time setup: empty adjacency lists, lg[i] = floor(log2(i)) for i <= 300000,
// and node[0] = 1 (x^0, the base entry of the power table kept in node[]).
void init(){
	memset(h,-1,sizeof(h));
	int i = 2;
	while(i <= 3e5)
	{
		lg[i] = lg[i >> 1] + 1;
		i++;
	}
	node[0] = 1;
}
// Push directed edge u -> v onto the front of u's adjacency list.
void add(int u,int v)
{
	int id = cnt++;
	e[id].ne = h[u];
	e[id].to = v;
	h[u] = id;
}

void dfs(int cur,int fath)
{
    
    
	if(st[cur]) return ;
	st[cur] = 1;
	dep[cur] = dep[fath]+1;
	fa[cur][0] = fath;
	for(int i = 1; i <= lg[dep[cur]];i++)
		fa[cur][i] = fa[fa[cur][i-1]][i-1];
	
	for(int i = h[cur];i != -1;i=e[i].ne)
	{
    
    
		int v = e[i].to;
		if(v != fath)
		{
    
    
			for(int j = 1; j <= 50; j++) node[j] = node[j-1]*dep[cur]%mod;
			for(int j = 1; j <= 50; j++) ans[v][j] = (node[j] + ans[cur][j])%mod;
			dfs(v,cur);
		}
			
	}
			
	
}

// Lowest common ancestor of a and b from the dep[] / fa[][] / lg[] tables built by dfs().
int LCA(int a,int b)
{
	// make b the deeper of the two nodes
	if(dep[b] < dep[a])
	{
		int tmp = a;
		a = b;
		b = tmp;
	}
	for(; dep[a] != dep[b]; )
		b = fa[b][lg[dep[b]-dep[a]]];
	if(a == b) return a;
	for(int k = lg[dep[a]]; k >= 0; k--)
	{
		bool differ = fa[a][k] != fa[b][k];
		if(differ)
		{
			a = fa[a][k];
			b = fa[b][k];
		}
	}
	return fa[a][0];
}
// Reads a tree (n nodes, n-1 edges "u v") and m queries "s t k".
// For each query it prints the path sum ans[s] + ans[t] - ans[lca] - ans[parent of lca]
// for exponent k, reduced into [0, mod).
int main()
{
	init();
	scanf("%d",&n);
	for(int left = n-1; left--; )
	{
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v);
		add(v,u);
	}
	dfs(1,0);
	scanf("%d",&m);
	while(m--)
	{
		int s,t;
		scanf("%d%d%d",&s,&t,&k);
		const int top = LCA(s,t);
		// subtract the lca once and its parent once so the lca is counted exactly once
		ll res = (ans[s][k] + ans[t][k] - ans[top][k] - ans[fa[top][0]][k]) % mod;
		res = (res % mod + mod) % mod;	// normalise a possibly negative remainder
		printf("%lld\n",res);
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_51687628/article/details/116998486