178 第K短路(A*算法优化)

1. 问题描述:

给定一张 N 个点(编号 1,2…N),M 条边的有向图,求从起点 S 到终点 T 的第 K 短路的长度,路径允许重复经过点或边。注意: 每条最短路中至少要包含一条边。

输入格式

第一行包含两个整数 N 和 M。接下来 M 行,每行包含三个整数 A,B 和 L,表示点 A 与点 B 之间存在有向边,且边长为 L。最后一行包含三个整数 S,T 和 K,分别表示起点 S,终点 T 和第 K 短路。

输出格式

输出占一行,包含一个整数,表示第 K 短路的长度,如果第 K 短路不存在,则输出 −1。

数据范围

1 ≤ S,T ≤ N ≤ 1000,
0 ≤ M ≤ 10 ^ 5,
1 ≤ K  ≤ 1000,
1 ≤ L ≤ 100

输入样例:

2 2
1 2 5
2 1 4
1 2 2

输出样例:

扫描二维码关注公众号,回复: 13514500 查看本文章

14
来源:https://www.acwing.com/problem/content/180/

2. 思路分析:

我们需要想一下如何枚举所有的路线,因为求解的是第K短路,所以在搜索的时候需要搜索所有的路径,那么整个搜索的空间是非常大的,所以需要使用比较强有力的搜索方式使得在遍历比较少的状态之后可以搜索到目标状态,对于搜索空间非常大的题目,特别适合于A*算法来解决,A*算法需要使用一个启发函数,因为当前状态到目标状态的估价距离需要小于等于当前状态到目标状态的真实距离所以可以将当前状态到目标状态的最短距离作为估价距离,这样无论当前状态怎么走最终到达目标状态的时候最短距离都小于真实的距离,所以作为估价距离是合适的,我们可以在输入的时候建立一个反向图,使用堆优化版的dijkstra算法求解从终点到其余点的最短距离;第二个问题是如何求解第K短路呢?直观上想,当终点第一次出队的时候那么求解的是起点到终点的最短距离,当终点第二次出队的时候求解的起点到终点的第2短距离....第k次出队的时候求解的是起点到终点的第k短距离,这个结论是否是正确的呢?其实是正确的,可以通过反证法来证明:若dist(2) >= d第2小 > d(v) + f(v),其中f为估价函数,v为第二短路上的某个顶点, 所以与优先队列中弹出来的是第2短路的节点就矛盾了所以终点第k次出队的时候肯定是起点到终点的第k短距离。因为求解的是第k短路,所以在搜索的时候凡是可以遍历到的节点都需要加入到队列中,直到终点被更新了k次之后或者队列为空的时候就停止了。

3. 代码如下:

import heapq
from typing import List


class Solution:
    # 测试数据: 
    # 3 3 
    # 1 2 1
    # 2 3 2
    # 2 3 4
    # 1 3 2
    # A*算法
    def astart(self, dis: List[int], s: int, t: int, k: int, n: int, g: List[List[int]]):
        q = list()
        # 堆中的第一个元素为起点到这个点的距离 + 这个点到终点的估价距离, 第二元素起点到这个点的真实距离, 第三个元素为这个点的下标
        heapq.heappush(q, (dis[s], 0, s))
        # count用来记录每一个点出现的次数用来判断终点是否有k次
        count = [0] * (n + 10)
        while q:
            p = heapq.heappop(q)
            ver, distance = p[2], p[1]
            count[ver] += 1
            if count[t] == k: 
                return distance
            for next in g[ver]:
                if count[next[0]] < k:
                    heapq.heappush(q, (distance + next[1] + dis[next[0]], distance + next[1], next[0]))
        return -1

    # 堆优化版的dijkstra算法, 计算终点到其余点的最短距离, 计算出来的这个记录可以作为这个点到终点的估价距离, 这个作为估价距离那么是合适的
    def dijkstra(self, n: int, s: int, dis: List[int], g: List[List[int]]):
        q = list()
        heapq.heappush(q, (0, s))
        vis = [0] * (n + 10)
        while q:
            p = heapq.heappop(q)
            ver = p[1]
            if vis[ver] == 1: continue
            vis[ver] = 1
            for next in g[ver]:
                if dis[next[0]] > dis[ver] + next[1]:
                    dis[next[0]] = dis[ver] + next[1]
                    heapq.heappush(q, (dis[next[0]], next[0]))

    def process(self):
        n, m = map(int, input().split())
        # g1, g2为正向图和反向图
        g1, g2 = [list() for i in range(n + 10)], [list() for i in range(n + 10)]
        for i in range(m):
            a, b, c = map(int, input().split())
            g1[a].append((b, c))
            g2[b].append((a, c))
        # s, t, k表示起点, 终点和第k短路
        s, t, k = map(int, input().split())
        # 因为起点和终点可能是同一个点所以需要判断一下, 题目中要求至少要存在一条边
        if s == t: k += 1
        INF = 10 ** 15
        dis = [INF] * (n + 10)
        dis[t] = 0
        self.dijkstra(n, t, dis, g2)
        return self.astart(dis, s, t, k, n, g1)


if __name__ == '__main__':
    print(Solution().process())

猜你喜欢

转载自blog.csdn.net/qq_39445165/article/details/121711160
今日推荐