spark的TimSort排序算法实现

Spark版本2.4.0。

 

Spark中的排序实现也是通过TimSort类实现,实现具体方式与JDK略有区别。

 

具体实现,在TimSort类的sort()方法的sort()方法中。

if (nRemaining < MIN_MERGE) {
  int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
  binarySort(a, lo, hi, lo + initRunLen, c);
  return;
}

当被排序的数组长度小于32时,具体的排序流程分为两步,首先通过countRunAndMakeAscending()方法标出从当前数组排序起点开始最长的一段单调区间并调均将这段区间调整为增序。

private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? super K> c) {
  assert lo < hi;
  int runHi = lo + 1;
  if (runHi == hi)
    return 1;

  K key0 = s.newKey();
  K key1 = s.newKey();

  // Find end of run, and reverse range if descending
  if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending
    while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0)
      runHi++;
    reverseRange(a, lo, runHi);
  } else {                              // Ascending
    while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0)
      runHi++;
  }

  return runHi - lo;
}

这里的代码可以看到,先从起点下标lo开始与其下一个下标位置runHi=lo+1进行比较,如果lo+1位置的数据小于lo,那么此时确认lo位置开始寻找lo位置开始的第一段递减区间,直到该递减区间结束,之后逆置这段区间为增序,记下区间结束下标,该下标与lo之差即为该段递增区间的长度。而如果一开始就为递增区间,那么就一直记录到该段递增结束即可,并按照上述操作返回该段区间总长度。

private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super K> c) {
  assert lo <= start && start <= hi;
  if (start == lo)
    start++;

  K key0 = s.newKey();
  K key1 = s.newKey();

  Buffer pivotStore = s.allocate(1);
  for ( ; start < hi; start++) {
    s.copyElement(a, start, pivotStore, 0);
    K pivot = s.getKey(pivotStore, 0, key0);

    // Set left (and right) to the index where a[start] (pivot) belongs
    int left = lo;
    int right = start;
    assert left <= right;
    /*
     * Invariants:
     *   pivot >= all in [lo, left).
     *   pivot <  all in [right, start).
     */
    while (left < right) {
      int mid = (left + right) >>> 1;
      if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)
        right = mid;
      else
        left = mid + 1;
    }
    assert left == right;

    /*
     * The invariants still hold: pivot >= all in [lo, left) and
     * pivot < all in [left, start), so pivot belongs at left.  Note
     * that if there are elements equal to pivot, left points to the
     * first slot after them -- that's why this sort is stable.
     * Slide elements over to make room for pivot.
     */
    int n = start - left;  // The number of elements to move
    // Switch is just an optimization for arraycopy in default case
    switch (n) {
      case 2:  s.copyElement(a, left + 1, a, left + 2);
      case 1:  s.copyElement(a, left, a, left + 1);
        break;
      default: s.copyRange(a, left, a, left + 1, n);
    }
    s.copyElement(pivotStore, 0, a, left);
  }
}

之后就会通过binarySort()方法进行二分插入排序,由于已经通过之前的方式得到了从开始位置的一段递增区间,那么就可以以此为基础开始从递增区间后的下一个元素开始往前插入,通过二分法确定插入位置,直到全部插入完毕,这是被排序数组小于32时候的情况。

 

接下来是大于32的情况,会采用归并排序。

在这里的归并排序,会通过一个SortState的内部类来辅助完成归并排序,在这里开始会初始化一个SortState。

runBase = new int[stackLen];
runLen = new int[stackLen];

SortState实际维护了两个堆栈,一个堆栈用来存储一个递增区间在当前数组中的起始下标,另一个则代表该区间的长度,通过堆栈的下标来定义一个在被维护数组中的有序区间。用来归并排序。

SortState sortState = new SortState(a, c, hi - lo);
int minRun = minRunLength(nRemaining);

通过minRunLength()方法确认,以被排序的数组长度得到一次最小的排序长度。

private int minRunLength(int n) {
  assert n >= 0;
  int r = 0;      // Becomes 1 if any 1 bits are shifted off
  while (n >= MIN_MERGE) {
    r |= (n & 1);
    n >>= 1;
  }
  return n + r;
}

最小数量根据具体的长度在16到32的区间。

do {
  // Identify next run
  int runLen = countRunAndMakeAscending(a, lo, hi, c);

  // If run is short, extend to min(minRun, nRemaining)
  if (runLen < minRun) {
    int force = nRemaining <= minRun ? nRemaining : minRun;
    binarySort(a, lo, lo + force, lo + runLen, c);
    runLen = force;
  }

  // Push run onto pending-run stack, and maybe merge
  sortState.pushRun(lo, runLen);
  sortState.mergeCollapse();

  // Advance to find next run
  lo += runLen;
  nRemaining -= runLen;
} while (nRemaining != 0);

之后按照小于32时候的场景,首先根据countRunAndMakeAscending()方法得到排序当前其实位置下标下的一段递增区间,如果该区间小于最小排序长度那么先得到剩数组未排序余排序长度和最小排序的最小者,先根据之前的binartSort()进行二分插入排序,使得在接下来的排序中,下一段递增区间为此次发生的二分插入排序的排序长度。

在得到上述的一个有序递增区间之后,则将该区间的起始位置和长度分别压入堆栈,并通过mergeCollapse()方法进行归并排序。

private void mergeCollapse() {
  while (stackSize > 1) {
    int n = stackSize - 2;
    if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])
      || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) {
      if (runLen[n - 1] < runLen[n + 1])
        n--;
    } else if (runLen[n] > runLen[n + 1]) {
      break; // Invariant is established
    }
    mergeAt(n);
  }
}

是否采取要开始排序在这里判断,如果当SortState中的递增区间数量为1,也没有必要进行排序,否则只有在堆栈数量为2,且栈底长度大于栈顶的时候才不会进行归并。

归并排序主要分为三个步骤。

int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);
assert k >= 0;
base1 += k;
len1 -= k;
if (len1 == 0)
  return;

/*
 * Find where the last element of run1 goes in run2. Subsequent elements
 * in run2 can be ignored (because they're already in place).
 */
len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);
assert len2 >= 0;
if (len2 == 0)
  return;

// Merge remaining runs, using tmp array with min(len1, len2) elements
if (len1 <= len2)
  mergeLo(base1, len1, base2, len2);
else
  mergeHi(base1, len1, base2, len2);

首先,在gallopRight()方法中,为第一步。

int maxOfs = len - hint;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {
  lastOfs = ofs;
  ofs = (ofs << 1) + 1;
  if (ofs <= 0)   // int overflow
    ofs = maxOfs;
}
if (ofs > maxOfs)
  ofs = maxOfs;

// Make offsets relative to b
lastOfs += hint;
ofs += hint;

在这里的一步,主要是从以第二个区间第一个元素为基准找到第一个递增区间中首个大于该元素的下标,由于区间都是递增,可以保证该元素之前都是小于第二个区间的第一个元素的,所以第一个区间的这一段元素不需要参与接下来的归并排序。

同样的第二步,gallopLeft()的操作类似,选取第一个区间的最后一个元素插入到从第二个区间左边最大值开始寻找首个小于该元素的值,这样可以保证左端都是大于第一个区间最后一个元素的,同样的第二个区间的这一段不用参与归并排序。

最后参与到第三步,归并排序的只剩下第一个区间的右边和第二个区间的左边,根据这两端的长度确定是采用mergeLo()还是mergeHi()方法进行排序。

在排序的过程中还有一些细节,比如在归并中,如果一次性将超过7个的第一个区间插入到结果数组,那么说明此时第一个区间中元素很有可能在接下来的一段时间都过小,重复之前的第一步操作可以使得性能大大提高。

 

 

在完成SortState的归并排序之后,回到TimSort的循环中,会继续按照之前的规则寻找下一段递增区间放入堆栈中准备排序直到数组被全部划分为递增区间存在堆栈中,并在结束后调用SortState的mergeForceCollapse()方法将堆栈中剩余的所有区间归并排序完毕,一次排序也宣告结束。

 

 

发布了141 篇原创文章 · 获赞 19 · 访问量 10万+

猜你喜欢

转载自blog.csdn.net/weixin_40318210/article/details/97338600