1171 距离(tarjan算法离线求解最近公共祖先)

1. 问题描述:

给出 n 个点的一棵树,多次询问两点之间的最短距离。注意:边是无向的。所有节点的编号是 1,2,…,n。

输入格式

第一行为两个整数 n 和 m。n 表示点数,m 表示询问次数;下来 n−1 行,每行三个整数 x,y,k,表示点 x 和点 y 之间存在一条边长度为 k;再接下来 m 行,每行两个整数 x,y,表示询问点 x 到点 y 的最短距离。树中结点编号从 1 到 n。

输出格式

共 m 行,对于每次询问,输出一行询问结果。

数据范围

2 ≤ n ≤ 10 ^ 4,
1 ≤ m ≤ 2 × 10 ^ 4,
0 < k ≤ 100,
1 ≤ x,y ≤ n

输入样例1:

2 2 
1 2 100 
1 2 
2 1

输出样例1:

100
100

输入样例2:

3 2
1 2 10
3 1 15
1 2
3 2

输出样例2:

10
25
来源:https://www.acwing.com/problem/content/description/1173/

2. 思路分析:

计算树中两个节点的最短距离其实有一个比较常用的方法,可以参照下图,我们可以先预处理出dis数组,dis[i]存储的是节点编号为i的节点到根节点的距离,如果求解编号为x和y之间的最短距离其实等于x到根节点的距离与y到根节点的距离之和减去他们的最近公共祖先到根节点的距离的2倍,所以这道题目的本质是最近公共祖先问题。先使用dfs遍历所有节点预处理出所有节点到根节点之间的距离,然后求解所有询问的最近公共祖先。求解最近公共祖先比较常用的方法是基于倍增的思想,这种方法是在线求解两点的最近公共祖先,在线求解是每读入一个询问那么立即求解出答案,而离线求解是读入所有的询问然后再求解所有的答案统一输出的方法,除了在线求解最近公共祖先之外,其实还可以离线求解最近公共祖先,其中有一个比较优美的算法叫做:tarjan算法,可以用来离线求解两个节点的最近公共祖先是,时间复杂度为O(m + n),m为询问的次数,n为节点的数目,每一个节点只会被访问一次。

tarjan算法求LCA主要基于深度优先遍历,可以看成是对向上标记法求LCA的优化,在深度优先遍历的过程中将遍历节点的状态分为三大类:

  • 已经遍历过且回溯完的节点,状态可以标记为2
  • 正在遍历的节点,状态可以标记为1
  • 当前还未搜到的点,状态可以标记为0

例如当前正在遍历编号为12的节点,我们求解12左边区域,例如编号为10的节点的最近公共祖先,可以发现10和12的最近公共祖先为5,对于左边区域的其余元素也是类似的,我们可以由此受到启发,将当前已经遍历过而且已经回溯完的节点合并到当前遍历的根节点中,合并两个集合可以使用并查集来实现,主要借助于一个父节点数组p,这样对于当前遍历的节点x与左边区域的某个节点的最近公共祖先,其实看一下左边区域的节点合并到哪一个集合中,也即当前元素所在集合的代表元素,这里可以使用并查集的find函数实现,具体在实现的时候什么时候将当前的节点合并到父节点上呢?我们可以在当前这个节点回溯的时候合并到父节点所在的集合中。具体实现的步骤:

  • 先存储一下所有的询问,python可以列表来存储,其中q[i]是一个列表,存储当前所有与编号为i的节点的询问,列表中的元素为元组类型,第一个元素为询问的另外一个元素,第二个元素为当前询问的编号,记录编号是为了能够将答案记录到对应的编号中,c++可以使用vector嵌套pair,两种语言中数据结构的作用是一样的,表现形式不一样;
  • 使用dfs预处理出每一个节点到根节点的距离,这个方法其实很好实现,在遍历节点的时候更新根节点到当前节点的距离即可,将结果存储到dis中;
  • tarjan算法的具体实现,基于深度优先遍历(dfs),在遍历的时候将当前节点的状态标记为1,遍历当前节点的邻接点,当遍历完当前子节点next之后将当前子节点next合并到当前遍历的根节点u所在的集合中,当遍历完当前节点的所有子节点之后查询所有与根节点相关的所有询问,使用并查集查找询问的另外一个节点所在的集合,将答案标记在答案的对应位置,并且最后将当前的根节点的状态标记为2,因为当前节点遍历完了而且要即将要回溯到上一层,所以应该将状态标记为2,。

3. 代码如下:

from typing import List


class Solution:
    # 并查集查找x的父节点与路径压缩
    def find(self, x: int, p: List[int]):
        if x != p[x]:
            p[x] = self.find(p[x], p)
        return p[x]

    # dis存储根节点到每一个节点的距离
    def dfs(self, u: int, fa: int, dis: List[int], g: List[List[int]]):
        for next in g[u]:
            if next[0] != fa:
                dis[next[0]] = dis[u] + next[1]
                self.dfs(next[0], u, dis, g)

    # 0表示未访问, 1表示正在遍历, 2表示已经访问并且已经回溯
    # u表示正在遍历的节点编号, dis表示节点到根节点的距离, p表示并查集的父节点列表, sta表示tarjan遍历节点的时候标记节点访问状态的列表, res距离对应访问的答案, query[i]表示与第i个节点的所有访问
    def tarjan(self, u: int, dis: List[int], p: List[int], sta: List[int], res: List[int], query: List[List[tuple]],
               g: List[List[int]]):
        # 标记当前的状态已经被访问
        sta[u] = 1
        for next in g[u]:
            if sta[next[0]] == 0:
                self.tarjan(next[0], dis, p, sta, res, query, g)
                # 合并当前的节点到根节点所在集合中
                p[next[0]] = u
        for x in query[u]:
            # 遍历与当前节点相关的所有访问, x[0]询问的另外一个节点, x[1]表示哪一次询问
            if sta[x[0]] == 2:
                anc = self.find(x[0], p)
                # 画图可以得到这个式子
                res[x[1]] = dis[x[0]] + dis[u] - 2 * dis[anc]
        # 执行到这里说明当前根节点下的所有子节点已经被访问并且已经回溯所以应该将当前的节点标记为2返回到上一层的时候已经回溯了
        sta[u] = 2

    def process(self):
        n, m = map(int, input().split())
        g = [list() for i in range(n + 1)]
        # n - 1条边
        for i in range(n - 1):
            a, b, c = map(int, input().split())
            # 有根无向树, 两个方向标记
            g[a].append((b, c))
            g[b].append((a, c))
        # 询问的相关信息
        query = [list() for i in range(n + 1)]
        for i in range(m):
            x, y = map(int, input().split())
            # 存储与x相关的询问, query存储的是另外一个节点的编号与询问的编号
            query[x].append((y, i))
            query[y].append((x, i))
        dis = [0] * (n + 1)
        # dfs和tarjan算法任意选择一个节点作为根节点即可, 这里选择1号点作为根节点
        self.dfs(1, -1, dis, g)
        # sta为tarjan算法标记节点遍历状态的列表
        res, sta = [0] * m, [0] * (n + 1)
        # 并查集的父节点列表, 辅助tarjan算法
        # 注意在初始化的时候下标从0开始, 这样p[i] = i, i从1开始
        p = [i for i in range(n + 1)]
        # 这里列表的参数比较多需要注意不要传错了, 最好是使用全局列表然后使用self来访问可以避免方法中传递很多参数
        self.tarjan(1, dis, p, sta, res, query, g)
        for i in range(m):
            print(res[i])


if __name__ == "__main__":
    Solution().process()

猜你喜欢

转载自blog.csdn.net/qq_39445165/article/details/121222038