题目链接
做些LCA的算法,还是很提高代码能力的,这道题就是典型的LCA模板,所以用它来练一下我的LCA算法还是很好的。
我们要求的是在一棵树上的任意两点的相互距离,既然在一棵树上,就可以直接调用LCA来解了。
我们先任取一根节点,我取的是1,我们从1开始,以1为总的根来预处理这棵树,先对其进行遍历整棵树,建立相互间的关系,并且向下更新的时候不忘整理出dis[](到根1的距离)。
void dfs(int u, int pre, int deep)
{
root[u][0] = pre;
depth[u] = deep;
for(int i=head[u]; i!=-1; i=edge[i].next)
{
int v=edge[i].to;
ll val=edge[i].val;
if(v == pre) continue;
dis[v] = dis[u] + val;
dfs(v, u, deep+1);
}
}
接下来就是预处理部分,我们要知道一个节点向上走(1<<i)步能走到哪个节点,所以在这里预处理一下,不然一会在LCA中处理,时间复杂度会很高,而且还不好处理,我这里这样的处理,可以使得时间复杂度降到log2()的形式。
void init(int st)
{
dfs(st, -1, 0);
for(int j=0; (1<<(j+1))<N; j++)
{
for(int i=1; i<=N; i++)
{
if(root[i][j]<0) root[i][j+1] = -1;
else root[i][j+1] = root[root[i][j]][j];
}
}
}
最后就是LCA这块了,我们既然知道了每个节点到根节点的深度,不如先讲它们放置到等深的高度,然后再逐一向上查询最早遇到的相等共有根节点就是了。
int LCA(int u, int v)
{
if(depth[u] > depth[v]) swap(u, v);
int temp = depth[v] - depth[u];
for(int i=0; (1<<i)<=temp; i++)
{
if( (1<<i)&temp ) v = root[v][i];
}
if(u == v) return u;
for(int i=log2(1.*N); i>=0; i--)
{
if(root[u][i] != root[v][i])
{
u = root[u][i];
v = root[v][i];
}
}
return root[u][0];
}
最后,我们返回的是等根前的那个子节点,所以我们要向上一步,得到等根,这就是用二分不断逼近的原理。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#define lowbit(x) ( x&(-x) )
#define INF 0x3f3f3f3f
#define pi 3.141592653589793
#define e 2.718281828459045
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int maxN = 40005;
int N, M, head[maxN], cnt, root[maxN][17], depth[maxN];
ll dis[maxN];
struct Eddge
{
int next, to;
ll val;
Eddge(int a=-1, int b=0, ll c=0):next(a), to(b), val(c) {}
}edge[maxN<<1];
void addEddge(int u, int v, ll val)
{
edge[cnt] = Eddge(head[u], v, val);
head[u] = cnt++;
}
void dfs(int u, int pre, int deep)
{
root[u][0] = pre;
depth[u] = deep;
for(int i=head[u]; i!=-1; i=edge[i].next)
{
int v=edge[i].to;
ll val=edge[i].val;
if(v == pre) continue;
dis[v] = dis[u] + val;
dfs(v, u, deep+1);
}
}
void init(int st)
{
dfs(st, -1, 0);
for(int j=0; (1<<(j+1))<N; j++)
{
for(int i=1; i<=N; i++)
{
if(root[i][j]<0) root[i][j+1] = -1;
else root[i][j+1] = root[root[i][j]][j];
}
}
}
int LCA(int u, int v)
{
if(depth[u] > depth[v]) swap(u, v);
int temp = depth[v] - depth[u];
for(int i=0; (1<<i)<=temp; i++)
{
if( (1<<i)&temp ) v = root[v][i];
}
if(u == v) return u;
for(int i=log2(1.*N); i>=0; i--)
{
if(root[u][i] != root[v][i])
{
u = root[u][i];
v = root[v][i];
}
}
return root[u][0];
}
int main()
{
int T; scanf("%d", &T);
while(T--)
{
scanf("%d%d", &N, &M);
cnt = 0;
memset(root, -1, sizeof(root));
memset(dis, 0, sizeof(dis));
memset(head, -1, sizeof(head));
for(int i=1; i<N; i++)
{
int e1, e2;
ll e3;
scanf("%d%d%lld", &e1, &e2, &e3);
addEddge(e1, e2, e3);
addEddge(e2, e1, e3);
}
init(1);
while(M--)
{
int L, R;
scanf("%d%d", &L, &R);
printf("%lld\n", dis[L] + dis[R] - 2*dis[LCA(L, R)]);
}
}
return 0;
}