《算法》第二章部分程序 part 3

▶ 书中第二章部分程序,加上自己补充的代码,包括各种优化的快排

  1 package package01;
  2 
  3 import edu.princeton.cs.algs4.In;
  4 import edu.princeton.cs.algs4.StdOut;
  5 import edu.princeton.cs.algs4.StdRandom;
  6 
  7 public class class01
  8 {
  9     private class01() {}
 10 
 11     public static void sort1(Comparable[] a)    // 基本的快排
 12     {
 13         StdRandom.shuffle(a);                   // 数组 a 随机化
 14         sortKernel1(a, 0, a.length - 1);
 15     }
 16 
 17     private static void sortKernel1(Comparable[] a, int lo, int hi)  // 快排递归内核
 18     {
 19         if (hi <= lo)
 20             return;
 21         int j = partition1(a, lo, hi);
 22         sortKernel1(a, lo, j - 1);              // 用不到 sortKernel1 的返回值
 23         sortKernel1(a, j + 1, hi);
 24     }
 25 
 26     private static int partition1(Comparable[] a, int lo, int hi)   // 快排分组
 27     {
 28         int i = lo, j = hi + 1;
 29         for (Comparable v = a[lo];; exch(a, i, j))      // a[lo] 作为分界值,每次循环前要调整 i 和 j,所以初值要过量
 30         {                                               // 没有 break 退出时调整两个错位的值
 31             for (i++; i < hi && less(a[i], v); i++);    // 寻找大于 a[lo] 的元素
 32             for (j--; j > lo && less(v, a[j]); j--);    // 寻找小于 a[lo] 的元素
 33             if (i >= j)                                 // 指针交叉,说明已经排好了
 34                 break;
 35         }
 36         exch(a, lo, j);                                 // 安置哨兵
 37         return j;                                       // 返回哨兵的位置
 38     }
 39 
 40     public static Comparable select(Comparable[] a, int k)  // 从数组中选出第 k 小的数字
 41     {
 42         StdRandom.shuffle(a);
 43         int lo = 0, hi = a.length - 1;
 44         for (; lo < hi;)
 45         {
 46             int i = partition1(a, lo, hi);      // 用 a[0] 分段,这是函数 partition 返回安置位置的原因
 47             if (i > k)                          // 左段过长
 48                 hi = i - 1;
 49             else if (i < k)                     // 右段过长
 50                 lo = i + 1;
 51             else
 52                 return a[i];                    // 找到了指定的分位数
 53         }
 54         return a[lo];                           // 指针交叉或 k 不合法?
 55     }
 56 
 57     public static void sort2(Comparable[] a)    // 优化版本,三路快排
 58     {
 59         StdRandom.shuffle(a);
 60         sortKernel2(a, 0, a.length - 1);
 61     }
 62 
 63     private static void sortKernel2(Comparable[] a, int lo, int hi) // 三路路快排的递归内核,没有 partition 程序
 64     {
 65         if (hi <= lo)
 66             return;
 67         int pLess = lo, pEqual = lo + 1, pGreater = hi; // 三个指针分别指向小于部分、等于部分、大于部分,且 pEqual 指向当前检查元素
 68         for (Comparable v = a[lo]; pEqual <= pGreater;)
 69         {
 70             int cmp = a[pEqual].compareTo(v);
 71             if (cmp < 0)
 72                 exch(a, pLess++, pEqual++);             // 当前元素较小,放到左边
 73             else if (cmp > 0)
 74                 exch(a, pEqual, pGreater--);            // 当前元素较大,放到右边
 75             else
 76                 pEqual++;                               // 当前元素等于哨兵,仅自增相等部分的指针
 77         }
 78         sortKernel2(a, lo, pLess - 1);
 79         sortKernel2(a, pGreater + 1, hi);
 80     }
 81 
 82     private static final int INSERTION_SORT_CUTOFF = 8; // 数组规模不大于 8 时使用插入排序
 83 
 84     public static void sort3(Comparable[] a)            // 优化版本,三采样 - 插入排序
 85     {
 86         sortKernel3(a, 0, a.length - 1);                // 没有数组随机化
 87     }
 88 
 89     private static void sortKernel3(Comparable[] a, int lo, int hi) // 递归内核
 90     {
 91         if (hi <= lo)
 92             return;
 93         if (hi - lo + 1 <= INSERTION_SORT_CUTOFF)                   // 规模较小,使用插入排序
 94         {
 95             insertionSort(a, lo, hi + 1);   //Insertion.sort(a, lo, hi + 1);,教材源码使用 edu.princeton.cs.algs4.Insertion
 96             return;
 97         }
 98         int j = partition3(a, lo, hi);
 99         sortKernel3(a, lo, j - 1);
100         sortKernel3(a, j + 1, hi);
101     }
102 
103     private static int partition3(Comparable[] a, int lo, int hi)   // 分组
104     {
105         exch(a, median3(a, lo, lo + (hi - lo + 1) / 2, hi), lo);    // 开头、中间和结尾各找 1 元求中位数作为哨兵
106 
107         int i = lo, j = hi + 1;
108         Comparable v = a[lo];               // 注意相比函数 partition1 少了一层循环
109 
110         for (i++; less(a[i], v); i++)
111         {
112             if (i == hi)                    // 低指针到达右端,所有元素都比哨兵小
113             {
114                 exch(a, lo, hi);            // 哨兵放到最后,返回哨兵位置
115                 return hi;
116             }
117         }
118         for (j--; less(v, a[j]); j--)
119         {
120             if (j == lo + 1)                // 高指针到达左端,所有元素都比哨兵大
121                 return lo;                  // 直接返回原哨兵位置
122         }
123         for (; i < j;)                      // 在低指针和高指针之间分组
124         {
125             exch(a, i, j);
126             for (i++; less(a[i], v); i++);
127             for (j--; less(v, a[j]); j--);
128         }
129         exch(a, lo, j);
130         return j;
131     }
132 
133     private static final int MEDIAN_OF_3_CUTOFF = 40;   // 数组规模较小时不采用三采样
134 
135     public static void sort4(Comparable[] a)            // 优化版本,使用
136     {
137         sortKernel4(a, 0, a.length - 1);
138     }
139 
140     private static void sortKernel4(Comparable[] a, int lo, int hi) // 递归内核
141     {
142         int n = hi - lo + 1;
143         if (n <= INSERTION_SORT_CUTOFF)                     // 小规模数组用插入排序
144         {
145             insertionSort(a, lo, hi + 1);
146             return;
147         }
148         else if (n <= MEDIAN_OF_3_CUTOFF)                   // 中规模数组用三采样
149             exch(a, median3(a, lo, lo + n / 2, hi), lo);
150         else                                                // 大规模数组使用 Tukey ninther 采样,相当于 9 点采样
151         {
152             int eps = n / 8, mid = lo + n / 2;
153             int ninther = median3(a,
154                 median3(a, lo, lo + eps, lo + eps + eps),
155                 median3(a, mid - eps, mid, mid + eps),
156                 median3(a, hi - eps - eps, hi - eps, hi)
157             );
158             exch(a, ninther, lo);
159         }
160 
161         int i = lo, p = lo, j = hi, q = hi + 1;         // Bentley-McIlroy 三向分组
162         for (Comparable v = a[lo]; ;)
163         {
164             for (i++; i < hi && less(a[i], v); i++);
165             for (j--; j > lo && less(v, a[j]); j--);
166 
167             if (i == j && eq(a[i], v))                  // 指针交叉,单独处理相等的情况
168                 exch(a, ++p, i);
169             if (i >= j)
170                 break;
171 
172             exch(a, i, j);
173             if (eq(a[i], v))
174                 exch(a, ++p, i);
175             if (eq(a[j], v))
176                 exch(a, --q, j);
177         }
178         i = j + 1;
179         for (int k = lo; k <= p; k++)
180             exch(a, k, j--);
181         for (int k = hi; k >= q; k--)
182             exch(a, k, i++);
183         sortKernel4(a, lo, j);
184         sortKernel4(a, i, hi);
185     }
186 
187     private static void insertionSort(Comparable[] a, int lo, int hi)   // 公用插入排序
188     {
189         for (int i = lo; i < hi; i++)
190         {
191             for (int j = i; j > lo && less(a[j], a[j - 1]); j--)
192                 exch(a, j, j - 1);
193         }
194     }
195 
196     private static int median3(Comparable[] a, int i, int j, int k)     // 计算 3 元素的中位数
197     {
198         return less(a[i], a[j]) ?
199             (less(a[j], a[k]) ? j : less(a[i], a[k]) ? k : i) :
200             (less(a[k], a[j]) ? j : less(a[k], a[i]) ? k : i);
201     }
202 
203     private static boolean less(Comparable v, Comparable w)
204     {
205         if (v == w)                 // 相同引用时避免内存访问
206             return false;
207         return v.compareTo(w) < 0;
208     }
209 
210     private static boolean eq(Comparable v, Comparable w)
211     {
212         if (v == w)                 // 相同引用时避免内存访问
213             return true;
214         return v.compareTo(w) == 0;
215     }
216 
217     private static void exch(Object[] a, int i, int j)
218     {
219         Object swap = a[i];
220         a[i] = a[j];
221         a[j] = swap;
222     }
223 
224     private static void show(Comparable[] a)
225     {
226         for (int i = 0; i < a.length; i++)
227             StdOut.println(a[i]);
228     }
229 
230     public static void main(String[] args)
231     {
232         In in = new In(args[0]);
233         String[] a = in.readAllStrings();
234 
235         class01.sort1(a);
236         //class01.sort2(a);
237         //class01.sort3(a);
238         //class01.sort4(a);
239         show(a);
240         /*
241         for (int i = 0; i < a.length; i++)  // 使用函数 select 进行排序,实际上是逐步挑出 a 中排第 1、第 2、第 3 …… 的元素
242         {
243             String ith = (String)class01.select(a, i);
244             StdOut.println(ith);
245         }
246         */
247     }
248 }

猜你喜欢

转载自www.cnblogs.com/cuancuancuanhao/p/9754199.html