【jdk8源码】TimSort算法——从头看到脚

      首先,在Java 6中Arrays.sort()和Collections.sort()使用的是MergeSort,而在Java7以后,内部实现换成了TimSort。我们通过看jdk8的Collections.sort()源码,来了解一下TimSort算法

简介

      Timsort是一个自适应的、混合的、稳定的排序算法,融合了归并算法和二分插入排序算法的精髓,在现实世界的数据中有着特别优秀的表现。它是由Tim Peter于2002年发明的,用在Python这个编程语言里面。这个算法之所以快,是因为它充分利用了现实世界的待排序数据里面,有很多子串是已经排好序的不需要再重新排序,利用这个特性并且加上合适的合并规则可以更加高效的排序剩下的待排序序列。

      当Timsort运行在部分排序好的数组里面的时候,需要的比较次数要远小于nlogn,也是远小于相同情况下的归并排序算法需要的比较次数。但是和其他的归并排序算法一样,最坏情况下的时间复杂度是O(nlogn)的水平。但是在最坏的情况下,Timsort需要的临时存储空间只有n/2,在最好的情况下,需要的额外空间是常数级别的。从各个方面都能够击败需要O(n)空间和稳定O(nlogn)时间的归并算法。

jdk源码

Collections

public static <T extends Comparable<? super T>> void sort(List<T> list) {
    list.sort(null);
}

List

default void sort(Comparator<? super E> c) {
        Object[] a = this.toArray();
         //在这里真正确定了使用的是TimSort算法
         //默认的Array. sort(int[] a)这里用的是双轴排序,以后再说
        Arrays.sort(a, (Comparator) c);
        ListIterator<E> i = this.listIterator();
        for (Object e : a) {
            i.next();
            i.set((E) e);
        }
    }

Arrays

	public static <T> void sort(T[] a, Comparator<? super T> c) {
        if (c == null) {
        	//这里数组里的元素如果是引用类型必须要实现Comparator<T>接口
        	//并对其排序内部的比较函数compare()进行重写,以便于我们按照我们的排序要求对引用对象数组极性排序,默认是升序排序,但可以自己自定义成降序排序
            sort(a); 
        } else {
            if (LegacyMergeSort.userRequested)
            	//这是兼容1.6之前旧版本,采用的是冒泡排序和归并排序
                legacyMergeSort(a, c);
            else
                TimSort.sort(a, 0, a.length, c, null, 0, 0);
        }
    }
    public static void sort(Object[] a) {
        if (LegacyMergeSort.userRequested)
            legacyMergeSort(a);
        else
        	//这里跟 TimSort.sort的思想是一样的
            ComparableTimSort.sort(a, 0, a.length, null, 0, 0);
    }

TimSort

	static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c,
                         T[] work, int workBase, int workLen) {
        assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;

        int nRemaining  = hi - lo;
        if (nRemaining < 2)
            return;   // 长度是0或者1 就不需要排序了。

        // 1 如果小于32,就用二分插入排序算法
        if (nRemaining < MIN_MERGE) {
        	// 1.1 先找自然升序序列(如果是倒序,会颠倒为正序排列),返回自然序列大小
            // 这里的自然序列就是数组中从lo以后,已经排好的序列
            int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
            // 1.2 二分插入排序
            binarySort(a, lo, hi, lo + initRunLen, c);
            return;
        }

        // 2 归并排序
        // 新建TimSort对象,保存栈的状态
        TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
        // 2.1 切分数组,返回大小在[16,32)之间的数,作为最小的分割槽大小
        int minRun = minRunLength(nRemaining);
        // 2.2 循环拆分数组,形成大小相差不多的run分割曹
        do {
            // 2.2.1 先找自然升序序列(如果是倒序,会颠倒为正序排列),返回自然序列大小
            int runLen = countRunAndMakeAscending(a, lo, hi, c);

            // 2.2.2 如果自然升序序列小于minRun ,需要按照minRun大小进行拆分并排序
            if (runLen < minRun) {
                int force = nRemaining <= minRun ? nRemaining : minRun;
                //把短的自然升序序列通过二分插入排序
                binarySort(a, lo, lo + force, lo + runLen, c);
                runLen = force;
            }

            // 2.2.3 把已经排好序的数列压入栈中,检查是不是需要合并
            ts.pushRun(lo, runLen);
            // 2.2.3 检查是不是需要合并
            ts.mergeCollapse();

            //把指针后移runLen距离,准备开始下一轮片段的排序
            lo += runLen;
            //剩下待排序的数量相应的减少 runLen
            nRemaining -= runLen;
        } while (nRemaining != 0);

        // 3 合并栈中所有待合并的序列
        assert lo == hi;
        ts.mergeForceCollapse();
        assert ts.stackSize == 1;
    }

1.1 先找自然升序序列

	/**
     * 这一段代码是TimSort算法中的一个小优化,它利用了数组中前面一段已有的顺序。
     * 如果是升序,直接返回统计结果;如果是降序,在返回之前,将这段数列倒置,
     * 以确保这断序列从首个位置到此位置的序列都是升序的。
     * 返回的结果是这种两种形式的,lo是这段序列的开始位置。
     * 为了保证排序的稳定性,这里要使用严格的降序,这样才能保证相等的元素不参与倒置子序列的过程,
     * 保证它们原本的顺序不被打乱。
     *
     * @param a  参与排序的数组
     * @param lo run中首个元素的位置
     * @param hi run中最后一个元素的后面一个位置,需要确保lo<hi
     * @param c  本次排序的比较器
     * @return 从首个元素开始的最长升序子序列的结尾位置+1 or 严格的降序子序列的结尾位置+1。
     */
	private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi,
                                                    Comparator<? super T> c) {
        assert lo < hi;
        int runHi = lo + 1;
        if (runHi == hi)
            return 1;

        // 找出最长升序序的子序列,如果降序,倒置之
        if (c.compare(a[runHi++], a[lo]) < 0) { // 前两个元素是降序,就按照降序统计
            while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
                runHi++;
            reverseRange(a, lo, runHi);
        } else {                              // 前两个元素是升序,按照升序统计
            while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
                runHi++;
        }

        return runHi - lo;
    }

1.2 二分插入排序

 	/**
     * 被优化的二分插入排序
     * 使用二分插入排序算法给指定一部分数组排序。这是给小数组排序的最佳方案。最差情况下
     * 它需要 O(n log n) 次比较和 O(n^2)次数据移动。
     * 如果开始的部分数据是有序的那么我们可以利用它们。这个方法默认数组中的位置lo(包括在内)到
     * start(不包括在内)的范围内是已经排好序的。
     *
     * @param a     被排序的数组
     * @param lo    待排序范围内的首个元素的位置
     * @param hi    待排序范围内最后一个元素的后一个位置
     * @param start 待排序范围内的第一个没有排好序的位置,确保 (lo <= start <= hi)
     * @param c     本次排序的比较器
     */
	private static <T> void binarySort(T[] a, int lo, int hi, int start,
                                       Comparator<? super T> c) {
        assert lo <= start && start <= hi;
        if (start == lo)
            start++;
        for ( ; start < hi; start++) {
        	//pivot 代表正在参与排序的值
            T pivot = a[start];

            //如果start 从起点开始,做下预处理;也就是原本就是无序的。
            int left = lo;
            int right = start;
            assert left <= right;
            /*
             * 利用二分查找,找到需要插入的位置,保证的逻辑:
             *   pivot >= all in [lo, left).
             *   pivot <  all in [right, start).
             */
            while (left < right) {
                int mid = (left + right) >>> 1;
                if (c.compare(pivot, a[mid]) < 0)
                    right = mid;
                else
                    left = mid + 1;
            }
            assert left == right;

             /**
             * 此时,仍然能保证:pivot >= [lo, left) && pivot < [left,start)
             * 所以,pivot的值应当在left所在的位置,然后需要把[left,start)范围内的内容整体右移一位腾出空间。
             * 如果pivot与区间中的某个值相等,left指正会指向重复的值的后一位(从left = mid + 1;这里可以看出),
             * 所以这里的排序是稳定的。
             */
            int n = start - left;  //需要移动的范围的长度
            // switch语句是一条小优化,1-2个元素的移动就不需要System.arraycopy了。
            // (这代码写的真是简洁,switch原来可以这样用)
            switch (n) {
                case 2:  a[left + 2] = a[left + 1];
                case 1:  a[left + 1] = a[left];
                         break;
                default: System.arraycopy(a, left, a, left + 1, n);
            }
            //移动过之后,把pivot的值放到应该插入的位置,就是left的位置了
            a[left] = pivot;
        }
    }

2.1 TimSort.minRunLength()切分数组,返回大小在[16,32)之间。

private static int minRunLength(int n) {
        assert n >= 0;
        int r = 0;      // Becomes 1 if any 1 bits are shifted off
        while (n >= MIN_MERGE) {
            r |= (n & 1); //计算最近一次n的末位数是1还是0
            n >>= 1; //缩小二倍,除以2
        }
        return n + r;
    }

2.2.3 检查是不是需要合并

 	/**
     * 检查栈中待归并的升序序列,如果他们不满足下列条件就把相邻的两个序列合并,
     * 直到他们满足下面的条件
     *
     * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
     * 2. runLen[i - 2] > runLen[i - 1]
     *
     * 每次添加新序列到栈中的时候都会执行一次这个操作。所以栈中的需要满足的条件
     * 需要靠调用这个方法来维护。
     *
     * 最差情况下,有点像玩2048。
     */
	private void mergeCollapse() {
        while (stackSize > 1) {
            int n = stackSize - 2;
            if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
                if (runLen[n - 1] < runLen[n + 1])
                    n--;
                mergeAt(n);
            } else if (runLen[n] <= runLen[n + 1]) {
                mergeAt(n);
            } else {
                break; // Invariant is established
            }
        }
    }

3 合并栈中所有待合并的序列

 	/**
     * 合并栈中所有待合并的序列,最后剩下一个序列。这个方法在整次排序中只执行一次
     */
    private void mergeForceCollapse() {
        while (stackSize > 1) {
            int n = stackSize - 2;
            if (n > 0 && runLen[n - 1] < runLen[n + 1])
                n--;
            mergeAt(n);
        }
    }

归并排序mergeAt

   /**
     * 合并在栈中位于i和i+1的两个相邻的升序序列。 i必须为从栈顶数,第二和第三个元素。
     * 换句话说i == stackSize - 2 || i == stackSize - 3
     *
     * @param i 待合并的第一个序列所在的位置
     */
	private void mergeAt(int i) {
		//校验
        assert stackSize >= 2;
        assert i >= 0;
        assert i == stackSize - 2 || i == stackSize - 3;
		//内部初始化
        int base1 = runBase[i];
        int len1 = runLen[i];
        int base2 = runBase[i + 1];
        int len2 = runLen[i + 1];
        assert len1 > 0 && len2 > 0;
        assert base1 + len1 == base2;

        /*
         * 记录合并后的序列的长度;如果i == stackSize - 3 就把最后一个序列的信息
         * 往前移一位,因为本次合并不关它的事。i+1对应的序列被合并到i序列中了,所以
         * i+1 数列可以消失了
         */
        runLen[i] = len1 + len2;
        if (i == stackSize - 3) {
            runBase[i + 1] = runBase[i + 2];
            runLen[i + 1] = runLen[i + 2];
        }
        //i+1消失了,所以长度也减下来了
        stackSize--;

        /*
         * 找出第二个序列的首个元素可以插入到第一个序列的什么位置,因为在此位置之前的序列已经就位了。
         * 它们可以被忽略,不参加归并。
         */
        int k = gallopRight(a[base2], a, base1, len1, 0, c);
        assert k >= 0;
        // 因为要忽略前半部分元素,所以起点和长度相应的变化
        base1 += k;
        len1 -= k;
        // 如果序列2 的首个元素要插入到序列1的后面,那就直接结束了,
        // !!! 因为序列2在数组中的位置本来就在序列1后面,也就是整个范围本来就是有序的!!!
        if (len1 == 0)
            return;

        /*
         * 跟上面相似,看序列1的最后一个元素(a[base1+len1-1])可以插入到序列2的什么位置(相对第二个序列起点的位置,非在数组中的位置),
         * 这个位置后面的元素也是不需要参与归并的。所以len2直接设置到这里,后面的元素直接忽略。
         */
        len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c);
        assert len2 >= 0;
        if (len2 == 0)
            return;

        // 合并剩下的两个有序序列,并且这里为了节省空间,临时数组选用 min(len1,len2)的长度
        // 优化的很细呢
        if (len1 <= len2)
            mergeLo(base1, len1, base2, len2);
        else
            mergeHi(base1, len1, base2, len2);
    }

归并排序gallopLeft

/**
     * 在一个序列中,将一个指定的key,从左往右查找它应当插入的位置;如果序列中存在
     * 与key相同的值(一个或者多个),那返回这些值中最左边的位置。
     *
     * 推断: 统计概率的原因,随机数字来说,两个待合并的序列的尾假设是差不多大的,从尾开始
     * 做查找找到的概率高一些。仔细算一下,最差情况下,这种查找也是 log(n),所以这里没有
     * 用简单的二分查找。
     * 这里先简单的做了一个大概的范围锁定lastOfs到ofs,然后再从这个区间中用二分查找法去查
     *
     * @param key  准备插入的key
     * @param a    参与排序的数组
     * @param base 序列范围的第一个元素的位置
     * @param len  整个范围的长度,一定有len > 0
     * @param hint 开始查找的位置,有0 <= hint <= len;越接近结果查找越快
     * @param c    排序,查找使用的比较器
     * @return 返回一个整数 k, 有 0 <= k <=n, 它满足 a[b + k - 1] < a[b + k]
     * 就是说key应当被放在 a[base + k],
     * 有 a[base,base+k) < key && key <=a [base + k, base + len)
     */
    private static <T> int gallopLeft(T key, T[] a, int base, int len, int hint,
                                      Comparator<? super T> c) {
        assert len > 0 && hint >= 0 && hint < len;
        int lastOfs = 0;
        int ofs = 1;
        if (c.compare(key, a[base + hint]) > 0) { // key > a[base+hint]
            // 遍历右边,直到 a[base+hint+lastOfs] < key <= a[base+hint+ofs]
            int maxOfs = len - hint;
            while (ofs < maxOfs && c.compare(key, a[base + hint + ofs]) > 0) {
                lastOfs = ofs;
                ofs = (ofs << 1) + 1;
                if (ofs <= 0)   // int overflow
                    ofs = maxOfs;
            }
            if (ofs > maxOfs)
                ofs = maxOfs;

            // 最终的ofs是这样确定的,满足条件 a[base+hint+lastOfs] < key <= a[base+hint+ofs]
            // 的一组
            // ofs:     1   3   7  15  31  63 2^n-1 ... maxOfs
            // lastOfs: 0   1   3   7  15  31 2^(n-1)-1  < ofs


            // 因为目前的offset是相对hint的,所以做相对变换
            lastOfs += hint;
            ofs += hint;
        } else { // key <= a[base + hint]
            // 遍历左边,直到[base+hint-ofs] < key <= a[base+hint-lastOfs]
            final int maxOfs = hint + 1;
            while (ofs < maxOfs && c.compare(key, a[base + hint - ofs]) <= 0) {
                lastOfs = ofs;
                ofs = (ofs << 1) + 1;
                if (ofs <= 0)   // int overflow
                    ofs = maxOfs;
            }
            if (ofs > maxOfs)
                ofs = maxOfs;
            // 确定ofs的过程与上面相同
            // ofs:     1   3   7  15  31  63 2^n-1 ... maxOfs
            // lastOfs: 0   1   3   7  15  31 2^(n-1)-1  < ofs

            // Make offsets relative to base
            int tmp = lastOfs;
            lastOfs = hint - ofs;
            ofs = hint - tmp;
        }
        assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;

        /*
         * 现在的情况是 a[base+lastOfs] < key <= a[base+ofs], 所以,key应当在lastOfs的
         * 右边,又不超过ofs。在base+lastOfs-1到 base+ofs范围内做一次二叉查找。
         */
        lastOfs++;
        while (lastOfs < ofs) {
            int m = lastOfs + ((ofs - lastOfs) >>> 1);

            if (c.compare(key, a[base + m]) > 0)
                lastOfs = m + 1;  // a[base + m] < key
            else
                ofs = m;          // key <= a[base + m]
        }
        assert lastOfs == ofs;    // so a[base + ofs - 1] < key <= a[base + ofs]
        return ofs;
    }

归并排序mergeLo

	/**
     * 使用固定空间合并两个相邻的有序序列,保持数组的稳定性。
     * 使用本方法之前保证第一个序列的首个元素大于第二个序列的首个元素;第一个序列的末尾元素
     * 大于第二个序列的所有元素
     *
     * 为了性能,这个方法在len1 <= len2的时候调用;它的姐妹方法mergeHi应该在len1 >= len2
     * 的时候调用。len1==len2的时候随便调用哪个都可以
     *
     * @param base1 要合并的第一个run分隔槽中第一个元素的索引
     * @param len1  要合并的第一个run分隔槽的长度(必须大于0)
     * @param base2 要合并的第二个run分隔槽中第一个元素的索引
     *              (must be aBase + aLen)
     * @param len2  要合并的第二个run分隔槽的长度(必须大于0)
     */
    private void mergeLo(int base1, int len1, int base2, int len2) {
        assert len1 > 0 && len2 > 0 && base1 + len1 == base2;

        //将第一个序列放到临时数组中
        T[] a = this.a; // For performance
        T[] tmp = ensureCapacity(len1);
        System.arraycopy(a, base1, tmp, 0, len1);

        int cursor1 = 0;       // 临时数组指针
        int cursor2 = base2;   // 序列2的指针,参与归并的另一个序列
        int dest = base1;      // 保存结果的指针

        // 这里先把第二个序列的首个元素,移动到结果序列中的位置,然后处理那些不需要归并的情况
        a[dest++] = a[cursor2++];

        // 序列2只有一个元素的情况,把它移动到指定位置之后,剩下的临时数组
        // 中的所有序列1的元素全部copy到后面
        if (--len2 == 0) {
            System.arraycopy(tmp, cursor1, a, dest, len1);
            return;
        }
        // 序列1只有一个元素的情况,把它移动到最后一个位置,为了不覆盖,先把序列2中的元素
        // 全部移走。这个是因为序列1中的最后一个元素比序列2中的所有元素都大,这是该方法执行的条件
        if (len1 == 1) {
            System.arraycopy(a, cursor2, a, dest, len2);
            a[dest + len2] = tmp[cursor1]; // Last elt of run 1 to end of merge
            return;
        }

        Comparator<? super T> c = this.c;  // 本次排序的比较器

        int minGallop = this.minGallop;    //  "    "       "     "      "

        // 不了解break标签的同学要补补Java基本功了
        outer:
        while (true) {
            /*
            * 这里加了两个值来记录一个序列连续比另外一个大的次数,根据此信息,可以做出一些
            * 优化
            * */
            int count1 = 0; // 序列1 连续 比序列2大多少次
            int count2 = 0; // 序列2 连续 比序列1大多少次

            /*
            * 这里是直接的归并算法的合并的部分,这里会统计count1合count2,
            * 如果其中一个大于一个阈值,就会跳出循环
            * */
            do {
                assert len1 > 1 && len2 > 0;
                if (c.compare(a[cursor2], tmp[cursor1]) < 0) {
                    a[dest++] = a[cursor2++];
                    count2++;
                    count1 = 0;

                    // 序列2没有元素了就跳出整次合并
                    if (--len2 == 0)
                        break outer;
                } else {
                    a[dest++] = tmp[cursor1++];
                    count1++;
                    count2 = 0;
                    // 如果序列1只剩下最后一个元素了就可以跳出循环
                    if (--len1 == 1)
                        break outer;
                }

            /*
            * 这个判断相当于 count1 < minGallop && count2 <minGallop
            * 因为count1和count2总有一个为0
            * */
            } while ((count1 | count2) < minGallop);



            /*
             * 执行到这里的话,一个序列会连续的的比另一个序列大,那么这种连续性可能持续的
             * 更长。那么我们就按照这个逻辑试一试。直到这种连续性被打破。根据找到的长度,
             * 直接连续的copy就可以了,这样可以提高copy的效率。
             */
            do {
                assert len1 > 1 && len2 > 0;
                // gallopRight就是之前用过的那个方法
                count1 = gallopRight(a[cursor2], tmp, cursor1, len1, 0, c);
                if (count1 != 0) {
                    System.arraycopy(tmp, cursor1, a, dest, count1);
                    dest += count1;
                    cursor1 += count1;
                    len1 -= count1;
                    if (len1 <= 1) // 结尾处理退化的序列
                        break outer;
                }
                a[dest++] = a[cursor2++];
                if (--len2 == 0) //结尾处理退化的序列
                    break outer;

                count2 = gallopLeft(tmp[cursor1], a, cursor2, len2, 0, c);
                if (count2 != 0) {
                    System.arraycopy(a, cursor2, a, dest, count2);
                    dest += count2;
                    cursor2 += count2;
                    len2 -= count2;
                    if (len2 == 0)
                        break outer;
                }
                a[dest++] = tmp[cursor1++];
                if (--len1 == 1)
                    break outer;
                // 这里对连续性比另外一个大的阈值减少,这样更容易触发这段操作,
                // 应该是因为前面的数据表现好,后面的数据类似的可能性更高?
                minGallop--;
            } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); //如果连续性还是很大的话,继续这样处理s


            if (minGallop < 0)
                minGallop = 0;

            //同样,这里如果跳出了那段循环,就证明数据的顺序程度不好,应当增加阈值,避免浪费资源
            minGallop += 2;
        }  //outer 结束


        this.minGallop = minGallop < 1 ? 1 : minGallop;  // Write back to field

        //这里处理收尾工作
        if (len1 == 1) {
            assert len2 > 0;
            System.arraycopy(a, cursor2, a, dest, len2);
            a[dest + len2] = tmp[cursor1]; //  Last elt of run 1 to end of merge
        } else if (len1 == 0) {
            //因为序列1中的最后一个值,比序列2中的所有值都大,所以,不可能序列1空了,序列2还有元素
            throw new IllegalArgumentException(
                    "Comparison method violates its general contract!");
        } else {
            assert len2 == 0;
            assert len1 > 1;
            System.arraycopy(tmp, cursor1, a, dest, len1);
        }
    }

总结

      TimSort算法主要进行了以下优化
1、利用自然升序序列
2、优化的二分插入排序,先利用二分查找找到位置,再进行移位,插入数据排序
3、拆分为大小查不多的run分割槽。因为将一个长序列和一个短序列进行归并排序从效率和代价的角度来看是不划算的,而两个长度均衡的序列进行归并排序时才是比较合理的也比较高效的。
4、优化的归并排序,
4.1 找到在左边run1中找到右边run2中的最小元素的位置x1,这样x1之前的元素就不用处理了;
4.2 找到在右边run2中找到左边run1中的最大元素的位置x2,这样x2之后的元素就不用处理了;
4.3 合并x1到x2之间的数
4.3.1 若右边只有一个,说明这个小于左边所有元素,直接插入最前
4.3.2 若左边只有一个,说明这个大于右边所有元素,直接放最后
4.3.3 多对多的情况,这里就要优化一下了。一个序列会连续的的比另一个序列大,那么这种连续性可能持续的更长。所以这里增加了一个连续性的监控,如若连续性大于7(这里是默认值,会根据情况调整),则改用4.1和4.2中的处理方式进行合并,否则进行一个个比较。
4.3.4 最后的收尾工作也进行了判断,左边剩一个直接放最后;左边剩多个,右边不剩,将左边剩的都放最后。

      Timsort是稳定的算法,当待排序的数组中已经有排序好的数,它的时间复杂度会小于nlogn。与其他合并排序一样,Timesort是稳定的排序算法,最坏时间复杂度是O(n log n)。在最坏情况下,Timsort算法需要的临时空间是n/2,在最好情况下,它只需要一个很小的临时存储空间。

Timsort原理介绍

TimSort算法 源码

猜你喜欢

转载自blog.csdn.net/u010168160/article/details/107608576