▶ 书中第四章部分程序,包括在加上自己补充的代码,两种 Prim 算法求最小生成树
● 简单 Prim 算法求最小生成树
1 package package01; 2 3 import edu.princeton.cs.algs4.In; 4 import edu.princeton.cs.algs4.StdOut; 5 import edu.princeton.cs.algs4.Edge; 6 import edu.princeton.cs.algs4.EdgeWeightedGraph; 7 import edu.princeton.cs.algs4.Queue; 8 import edu.princeton.cs.algs4.MinPQ; 9 10 public class class01 11 { 12 private static final double FLOATING_POINT_EPSILON = 1E-12; 13 14 private boolean[] marked; // 顶点是否在生成树中 15 private double weight; // 生成树的权值和 16 private Queue<Edge> mst; // 生成树包含的边 17 private MinPQ<Edge> pq; // 搜索队列 18 19 public class01(EdgeWeightedGraph G) 20 { 21 marked = new boolean[G.V()]; 22 mst = new Queue<Edge>(); 23 pq = new MinPQ<Edge>(); 24 for (int v = 0; v < G.V(); v++) // 对每个没有遍历的节点都使用 prim 25 { 26 if (!marked[v]) 27 prim(G, v); 28 } 29 } 30 31 private void prim(EdgeWeightedGraph G, int s) 32 { 33 for (scan(G, s); !pq.isEmpty();) 34 { 35 Edge e = pq.delMin(); // 取出权值最小的边 36 int v = e.either(), w = e.other(v); 37 if (marked[v] && marked[w]) // 若该边两端都遍历过,不要(由于 scan,v 与 w 之一肯定被遍历过) 38 continue; 39 mst.enqueue(e); // 将权值最小的边加入生成树 40 weight += e.weight(); // 更新权值和 41 if (!marked[v]) // 从新边的新顶点继续收集新的边 42 scan(G, v); 43 if (!marked[w]) 44 scan(G, w); 45 } 46 } 47 48 private void scan(EdgeWeightedGraph G, int v) // 将一端为 v、另一端没有遍历过的边放入队列中 49 { 50 marked[v] = true; 51 for (Edge e : G.adj(v)) 52 { 53 if (!marked[e.other(v)]) 54 pq.insert(e); 55 } 56 } 57 58 public Iterable<Edge> edges() 59 { 60 return mst; 61 } 62 63 public double weight() 64 { 65 return weight; 66 } 67 68 public static void main(String[] args) 69 { 70 In in = new In(args[0]); 71 EdgeWeightedGraph G = new EdgeWeightedGraph(in); 72 class01 mst = new class01(G); 73 for (Edge e : mst.edges()) 74 StdOut.println(e); 75 StdOut.printf("%.5f\n", mst.weight()); 76 } 77 }
● 改进,使用索引最小优先队列来建立搜索队列,记录(起点到)每个顶点的距离来判断是否将新边加入生成树
1 package package01; 2 3 import edu.princeton.cs.algs4.In; 4 import edu.princeton.cs.algs4.StdOut; 5 import edu.princeton.cs.algs4.Edge; 6 import edu.princeton.cs.algs4.EdgeWeightedGraph; 7 import edu.princeton.cs.algs4.Queue; 8 import edu.princeton.cs.algs4.IndexMinPQ; 9 10 public class class01 11 { 12 private static final double FLOATING_POINT_EPSILON = 1E-12; 13 14 private boolean[] marked; 15 private Edge[] edgeTo; // 除了搜索起始顶点,新加入每条边对应一个顶点,顶点 v 对应的边是 edgeTo[v] 16 private double[] distTo; // 生成树到每个顶点的距离,用于衡量新边是否值得加入生成树 17 private IndexMinPQ<Double> pq;// 搜索队列 18 19 public class01(EdgeWeightedGraph G) 20 { 21 marked = new boolean[G.V()]; 22 edgeTo = new Edge[G.V()]; 23 distTo = new double[G.V()]; 24 pq = new IndexMinPQ<Double>(G.V()); 25 for (int v = 0; v < G.V(); v++) 26 distTo[v] = Double.POSITIVE_INFINITY; 27 for (int v = 0; v < G.V(); v++) 28 { 29 if (!marked[v]) 30 prim(G, v); 31 } 32 } 33 34 private void prim(EdgeWeightedGraph G, int s) 35 { 36 distTo[s] = 0.0; // 搜索起点对应的距离为 0 37 for (pq.insert(s, distTo[s]); !pq.isEmpty();) 38 { 39 int v = pq.delMin(); // 每次取距离最小的顶点来开花(防止同一个顶点可以有对多条边连到树上) 40 scan(G, v); // 注意 scan 只负责在给定的顶点上开花,不负责递归 41 } 42 } 43 44 private void scan(EdgeWeightedGraph G, int v) 45 { 46 marked[v] = true; 47 for (Edge e : G.adj(v)) 48 { 49 int w = e.other(v); 50 if (marked[w]) // 边 v-w 两端都被遍历过,在队列中 51 continue; 52 if (e.weight() < distTo[w]) // 边 v-w 的权值小于顶点 w 的距离,说明加入该条边后生成树的总权值会下降 53 { 54 distTo[w] = e.weight(); // 加入边 v-w,更新 distTo 和 edgeTo 55 edgeTo[w] = e; 56 if (pq.contains(w)) // 搜若索队列中已经存在 w 则更新其 distTo(键值),不存在则将 w 加入 57 pq.decreaseKey(w, distTo[w]); 58 else 59 pq.insert(w, distTo[w]); 60 } 61 } 62 } 63 64 public Iterable<Edge> edges() 65 { 66 Queue<Edge> mst = new Queue<Edge>(); 67 for (int v = 0; v < edgeTo.length; v++)// 遍历边列表,把每个顶点对应的边加入队列中 68 { 69 Edge e = edgeTo[v]; 70 if (e != null) 71 mst.enqueue(e); 72 } 73 return mst; 74 } 75 76 public double weight() 77 { 78 double weight = 0.0; 79 for (Edge e : edges()) 80 weight += e.weight(); 81 return weight; 82 } 83 84 public static void main(String[] args) 85 { 86 In in = new In(args[0]); 87 EdgeWeightedGraph G = new EdgeWeightedGraph(in); 88 class01 mst = new class01(G); 89 for (Edge e : mst.edges()) 90 StdOut.println(e); 91 StdOut.printf("%.5f\n", mst.weight()); 92 } 93 }