编程语言的sort方法实现

sort实现

  • PHP
  • JAVA
  • Python
  • Golang

PHP

版本 8.1

内部实现机制为:快速排序、插入排序

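  • 小于 6 个元素(0~5 个)时,用固定的 if/else 比较交换序列直接排序(zend_sort_2 ~ zend_sort_5)
  • 大于等于 6 个元素时,先对前 6 个元素做直接插入排序(正好 6 个元素时到此结束);小规模数据用插入排序,可避免快排递归带来的栈空间和选取枢轴的额外开销
  • 大于 6 且小于等于 16 个元素时,从第 7 个元素开始使用变种插入排序:向前查找插入位置时每次跨 2 个元素(步长 siz2 = 2 * siz,siz 为单个元素的字节数)
  • 大于 16 个元素时,使用快速排序;划分出的子区间缩小到 16 个元素以内后,再改用上述插入排序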

快排,在数组中选一个元素作为中心元素,把数组元素划分为小于中心元素的集合以及大于中心元素的集合,两个集合放到中心元素的两边。然后再对两个集合分别划分,一直划分到元素个数为 1 或 0 的时候才停下。

左右指针法

  1. 选取一个关键字(key)作为枢轴
  2. 设置两个变量left = 0;right = N - 1
  3. left一直向后走,直到找到一个大于key的值,right从后至前,直至找到一个小于key的值,然后交换这两个数
  4. 重复第三步,一直往后找,直到left和right相遇,这时将key放置left的位置即可
  5. 再对左右两个区间分别重复第一步到第四步(每个区间重新选取 key),直到各区间只剩一个数或为空(递归)

bool sort ( array &$array [, int $sort_flags = SORT_REGULAR ] )

传递一个数组作为第一个参数,该参数按引用传递,排序结果直接作用在数组本身上,调用方无需使用返回值。该函数默认(SORT_REGULAR)对数组元素进行升序排序,支持对数字和字符串进行排序。注意 sort() 不保留原有键名,排序后数组会被重新编号(对应下文 zend_hash_sort 调用中的 renumber = 1)。

排序算法来自 LLVM 项目的 libc++ 实现

https://github.com/llvm/llvm-project/blob/main/libcxx/include/__algorithm/sort.h#L272

// Signature only (body omitted in these notes): libc++'s internal sort entry
// point in __algorithm/sort.h, the file linked above.
template <class _Compare, class _RandomAccessIterator>
void __sort(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare __comp)

https://github.com/php/php-src/blob/master/ext/standard/array.c

ext/standard/array.c

/* PHP's userland sort(): sorts the array argument in place and reports success
 * with RETURN_TRUE. Parameter parsing and the choice of cmp are elided ("...")
 * in this excerpt; presumably cmp is derived from $sort_flags -- confirm in
 * ext/standard/array.c. */
PHP_FUNCTION(sort)
{
	...
	/* renumber = 1: zend_hash_sort_ex then swaps buckets with
	 * zend_hash_bucket_renum_swap. */
	zend_hash_sort(Z_ARRVAL_P(array), cmp, 1);

	RETURN_TRUE;
}

Zend/zend_hash.h

/*
 * Sort the buckets of a HashTable in place.
 *   sort_func    - generic array sort routine applied to the bucket array
 *   compare_func - bucket comparator
 *   renumber     - selects the renumbering swap callback in the implementation
 */
ZEND_API void ZEND_FASTCALL zend_hash_sort_ex(
	HashTable *ht,
	sort_func_t sort_func,
	bucket_compare_func_t compare_func,
	bool renumber);

/* Shorthand for zend_hash_sort_ex() using zend_sort as the sort routine. */
#define zend_hash_sort(ht, compare_func, renumber) \
	zend_hash_sort_ex((ht), zend_sort, (compare_func), (renumber))

Zend/zend_hash.c

/* Sort a HashTable's bucket array in place.
 *   sort     - array sort routine (zend_sort when reached via the zend_hash_sort macro)
 *   compar   - bucket comparator, passed on as a generic compare_func_t
 *   renumber - when true, buckets are swapped with zend_hash_bucket_renum_swap;
 *              otherwise the swap callback depends on HASH_FLAG_PACKED. */
ZEND_API void ZEND_FASTCALL zend_hash_sort_ex(HashTable *ht, sort_func_t sort, bucket_compare_func_t compar, bool renumber)
{
	...

	// Hand the raw Bucket array (nNumUsed slots of sizeof(Bucket) bytes) to the
	// sort routine (zend_sort for the zend_hash_sort macro). The swap callback is
	// chosen from renumber and from whether the table is packed.
	sort((void *)ht->arData, ht->nNumUsed, sizeof(Bucket), (compare_func_t) compar,
			(swap_func_t)(renumber? zend_hash_bucket_renum_swap :
				((HT_FLAGS(ht) & HASH_FLAG_PACKED) ? zend_hash_bucket_packed_swap : zend_hash_bucket_swap)));

	...
}

Zend/zend_sort.h

/* Hybrid sort of nmemb elements of siz bytes each: insertion sort for
 * nmemb <= 16, quicksort otherwise. cmp(x, y) > 0 means x sorts after y;
 * swp exchanges two elements in place. */
ZEND_API void zend_sort(void *base, size_t nmemb, size_t siz,
		compare_func_t cmp, swap_func_t swp);

/* Insertion sort for small arrays; nmemb <= 5 uses fixed comparison
 * sequences. Same parameter conventions as zend_sort. */
ZEND_API void zend_insert_sort(void *base, size_t nmemb, size_t siz,
		compare_func_t cmp, swap_func_t swp);

Zend/zend_sort.c

/* Order two elements: afterwards *a does not compare greater than *b. */
static inline void zend_sort_2(void *a, void *b, compare_func_t cmp, swap_func_t swp) {
	int out_of_order = cmp(a, b) > 0;

	if (out_of_order) {
		swp(a, b);
	}
}

/* Sort three elements in place with at most three comparisons
 * (same decision tree as libc++'s three-element sort). */
static inline void zend_sort_3(void *a, void *b, void *c, compare_func_t cmp, swap_func_t swp) {
	if (cmp(a, b) > 0) {
		if (cmp(c, b) > 0) {
			/* b is the smallest: bring it to the front, then order the tail. */
			swp(a, b);
			if (cmp(b, c) > 0) {
				swp(b, c);
			}
		} else {
			/* c <= b < a: fully reversed, one swap fixes it. */
			swp(a, c);
		}
		return;
	}

	/* a <= b: only c may be out of place. */
	if (cmp(b, c) > 0) {
		swp(b, c);
		if (cmp(a, b) > 0) {
			swp(a, b);
		}
	}
}

/* Sort four elements: sort the first three, then sink d leftwards into place. */
static void zend_sort_4(void *a, void *b, void *c, void *d, compare_func_t cmp, swap_func_t swp) {
	void *slot[4] = { a, b, c, d };
	int pos;

	zend_sort_3(a, b, c, cmp, swp);
	/* Stop as soon as the left neighbour is not greater. */
	for (pos = 3; pos > 0 && cmp(slot[pos - 1], slot[pos]) > 0; pos--) {
		swp(slot[pos - 1], slot[pos]);
	}
}

/* Sort five elements: sort the first four, then sink e leftwards into place. */
static void zend_sort_5(void *a, void *b, void *c, void *d, void *e, compare_func_t cmp, swap_func_t swp) {
	void *slot[5] = { a, b, c, d, e };
	int pos;

	zend_sort_4(a, b, c, d, cmp, swp);
	/* Stop as soon as the left neighbour is not greater. */
	for (pos = 4; pos > 0 && cmp(slot[pos - 1], slot[pos]) > 0; pos--) {
		swp(slot[pos - 1], slot[pos]);
	}
}

/* Insertion sort of nmemb elements, each siz bytes wide, starting at base.
 * cmp(x, y) > 0 means x must come after y; swp exchanges two elements in place.
 * nmemb <= 5: fixed comparison sequences (zend_sort_2 .. zend_sort_5).
 * nmemb >= 6: straight insertion sort over the first 6 elements, then an
 * insertion sort whose backward search inspects every second element. */
ZEND_API void zend_insert_sort(void *base, size_t nmemb, size_t siz, compare_func_t cmp, swap_func_t swp){
	switch (nmemb) {
		case 0:
		case 1:
			/* Nothing to sort. */
			break;
		case 2:
			zend_sort_2(base, (char *)base + siz, cmp, swp);
			break;
		case 3:
			zend_sort_3(base, (char *)base + siz, (char *)base + siz + siz, cmp, swp);
			break;
		case 4:
			{
				size_t siz2 = siz + siz;
				zend_sort_4(base, (char *)base + siz, (char *)base + siz2, (char *)base + siz + siz2, cmp, swp);
			}
			break;
		case 5:
			{
				size_t siz2 = siz + siz;
				zend_sort_5(base, (char *)base + siz, (char *)base + siz2, (char *)base + siz + siz2, (char *)base + siz2 + siz2, cmp, swp);
			}
			break;
		default:
			{
				char *i, *j, *k;
				char *start = (char *)base;
				char *end = start + (nmemb * siz);
				size_t siz2= siz + siz;
				char *sentry = start + (6 * siz); // points at the 7th element (index 6)

				// Phase 1: plain insertion sort over the first 6 elements
				// (this is the whole array when nmemb == 6).
				for (i = start + siz; i < sentry; i += siz) {
					// j = last element of the sorted prefix.
					j = i - siz;
					// *i is not smaller than the last sorted element: already in place.
					if (!(cmp(j, i) > 0)) {
						continue;
					}
					// Scan the sorted prefix backwards for the insertion point.
					while (j != start) {
						j -= siz;
						// Found an element not greater than *i: insert right after it.
						if (!(cmp(j, i) > 0)) {
							j += siz;
							break;
						}
					}
					// Move *i down to j with adjacent swaps. Searching first keeps i fixed
					// while comparing (nothing moves during the search); the number of
					// swaps is the same as in a swap-as-you-go insertion sort, and
					// elements are only ever moved through the swp callback.
					for (k = i; k > j; k -= siz) {
						swp(k, k - siz);
					}
				}

				// Phase 2 (nmemb > 6): insertion sort from sentry onward. The backward
				// search steps two elements at a time (siz2) and settles the remaining
				// one-element ambiguity with one extra comparison.
				// Since i >= start + 6 * siz, the stepping below always reaches either
				// start or start + siz, so j never goes before start.
				for (i = sentry; i < end; i += siz) {
					// j = last element of the sorted prefix.
					j = i - siz;
					// Already in place.
					if (!(cmp(j, i) > 0)) {
						continue;
					}
					// Step back two elements at a time.
					do {
						j -= siz2;
						// *j <= *i: the insertion point is one or two elements after j;
						// one more comparison decides which.
						if (!(cmp(j, i) > 0)) {
							j += siz;
							if (!(cmp(j, i) > 0)) {
								j += siz;
							}
							break;
						}
						// Reached the first element and it is still greater than *i:
						// *i becomes the new first element.
						if (j == start) {
							break;
						}
						// At the second element: cannot step back two more, so compare
						// with the first element directly.
						// NOTE(review): this test is cmp(i, j) > 0 rather than
						// !(cmp(j, i) > 0), so when *i ties with the first element it is
						// inserted before it, unlike the other paths which insert after
						// equal elements. Harmless only if cmp never reports ties.
						if (j == start + siz) {
							j -= siz;
							if (cmp(i, j) > 0) {
								j += siz;
							}
							break;
						}
					} while (1);
					// Move *i down to j with adjacent swaps.
					for (k = i; k > j; k -= siz) {
						swp(k, k - siz);
					}
				}
			}
			break;
	}
}

/* Sort nmemb elements of siz bytes at base (hybrid quicksort).
 * Ranges of at most 16 elements go to zend_insert_sort. Larger ranges are
 * partitioned around a median-of-3 pivot (median-of-5 for >= 1024 elements).
 * The smaller side is sorted recursively; the larger side is handled by the
 * next iteration of the outer loop, which keeps the recursion depth logarithmic. */
ZEND_API void zend_sort(void *base, size_t nmemb, size_t siz, compare_func_t cmp, swap_func_t swp)
{
	while (1) {
		if (nmemb <= 16) {
			// Small range: insertion sort.
			zend_insert_sort(base, nmemb, siz, cmp, swp);
			return;
		} else {
			// Quicksort partition with two converging pointers.
			char *i, *j;
			char *start = (char *)base;
			char *end = start + (nmemb * siz);
			// Index of the middle element.
			size_t offset = (nmemb >> Z_L(1));
			// Pivot candidate: the middle element.
			char *pivot = start + (offset * siz);

			// nmemb >= 1024 (nmemb >> 10 is non-zero): median of five samples taken at
			// the first element, 1/4, 1/2, 3/4 and the last element. Sub-ranges that are
			// still >= 1024 elements use the same rule on later iterations.
			if ((nmemb >> Z_L(10))) {
				// Quarter-length step, in bytes.
				size_t delta = (offset >> Z_L(1)) * siz;
				// Sorting the five samples leaves their median in the middle slot (pivot).
				zend_sort_5(start, start + delta, pivot, pivot + delta, end - siz, cmp, swp);
			} else {
				// nmemb < 1024: median of three (first, middle, last).
				zend_sort_3(start, pivot, end - siz, cmp, swp);
			}
			// Park the pivot value in the second slot. The first element is now <= pivot
			// (it belongs to the left part) and the last element is >= pivot.
			swp(start + siz, pivot);
			// From here on, pivot points at the second element.
			pivot = start + siz;
			// Partition the elements after the pivot (ascending order w.r.t. cmp):
			// elements smaller than the pivot end up left of the meeting point, larger
			// ones right of it; elements equal to the pivot may land on either side.
			i = pivot + siz;
			j = end - siz;
			while (1) {
				// Advance i while *i < pivot.
				while (cmp(pivot, i) > 0) {
					i += siz;
					if (UNEXPECTED(i == j)) {
						goto done;
					}
				}
				// Step j left before examining it (it starts on the last element,
				// which is already known to be >= pivot).
				j -= siz;
				if (UNEXPECTED(j == i)) {
					goto done;
				}
				// Move j left while *j > pivot.
				while (cmp(j, pivot) > 0) {
					j -= siz;
					if (UNEXPECTED(j == i)) {
						goto done;
					}
				}
				// *i >= pivot and *j <= pivot: exchange them.
				swp(i, j);
				// Continue scanning with i.
				i += siz;
				if (UNEXPECTED(i == j)) {
					goto done;
				}
			}
done:
			// i == j: i is the first slot of the right part. Swapping the pivot with the
			// last element of the left part (i - siz) puts the pivot between the parts:
			// everything to its left is <= pivot, everything to its right is >= pivot.
			swp(pivot, i - siz);
			// Recurse into the smaller part and loop on the larger one. The left part is
			// [start, i - siz) (pivot excluded); the right part is [i, end).
			if ((i - siz) - start < end - i) {
				zend_sort(start, (i - start)/siz - 1, siz, cmp, swp);
				base = i;
				nmemb = (end - i)/siz;
			} else {
				zend_sort(i, (end - i)/siz, siz, cmp, swp);
				nmemb = (i - start)/siz - 1;
			}
		}
	}
}

JAVA

版本 1.8

内部实现机制为:TimSort、二分插入排序(binarySort)

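快速排序是不稳定的,而 Timsort 是稳定的(相等元素排序后保持原有的相对顺序)。因此 Java 对对象数组和列表使用 TimSort;而对 int[] 等基本类型数组使用双轴快速排序(DualPivotQuicksort),因为相等的基本类型值无法区分,稳定性没有意义。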

public class ArrayList<E> extends AbstractList<E>
        implements List<E>, RandomAccess, Cloneable, java.io.Serializable
{
    /**
     * Sorts this list in place with comparator c (natural ordering when c is
     * null, via Arrays.sort). If modCount changed during the sort, i.e. the
     * list was structurally modified meanwhile, a ConcurrentModificationException
     * is thrown; otherwise the sort itself is counted as one modification.
     */
    public void sort(Comparator<? super E> c) {
        final int expectedModCount = modCount;
        Arrays.sort((E[]) elementData, 0, size, c);
        if (modCount == expectedModCount) {
            modCount++;
            return;
        }
        throw new ConcurrentModificationException();
    }
}

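// modCount 是 ArrayList 的结构修改计数器,expectedModCount 是排序前对它的快照,这段代码本身并没有修改这两个变量。排序期间如果列表结构被修改(例如比较器 c 在 compare 中调用了 add/remove,或者其他线程并发地对这个 ArrayList 做了 add、remove 等操作),modCount 就会改变,modCount != expectedModCount 成立,说明本次排序结果不可信,于是抛出 ConcurrentModificationException。因此即使在单线程中,这个判断也不是冗余的。这只是一种尽力而为的快速失败(fail-fast)检测,并不能让 ArrayList 变成线程安全的,ArrayList 本身仍然是线程不安全的。最后的 modCount++ 把排序本身也记为一次修改,使正在进行中的迭代器能够快速失败。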

public class Arrays {
    /**
     * Sorts the range a[fromIndex, toIndex) with comparator c.
     * A null comparator falls back to natural ordering (Comparable elements).
     * Otherwise the range is validated and sorted with TimSort, or with the
     * legacy merge sort when LegacyMergeSort.userRequested is set (an opt-in
     * kept only for compatibility).
     */
    public static <T> void sort(T[] a, int fromIndex, int toIndex,
                                Comparator<? super T> c) {
        if (c == null) {
            sort(a, fromIndex, toIndex);
            return;
        }
        rangeCheck(a.length, fromIndex, toIndex);
        if (!LegacyMergeSort.userRequested) {
            TimSort.sort(a, fromIndex, toIndex, c, null, 0, 0);
        } else {
            legacyMergeSort(a, fromIndex, toIndex, c);
        }
    }

    /**
     * Sorts the range a[fromIndex, toIndex) into ascending natural order; the
     * elements must be Comparable. Uses ComparableTimSort, adapted from Python's
     * timsort (https://svn.python.org/projects/python/trunk/Objects/listsort.txt),
     * unless the legacy merge sort was requested.
     */
    public static void sort(Object[] a, int fromIndex, int toIndex) {
        rangeCheck(a.length, fromIndex, toIndex);
        if (!LegacyMergeSort.userRequested) {
            ComparableTimSort.sort(a, fromIndex, toIndex, null, 0, 0);
            return;
        }
        legacyMergeSort(a, fromIndex, toIndex);
    }
}

/**
 * Excerpt of JDK 1.8 java.util.ComparableTimSort: TimSort for arrays whose
 * elements are Comparable (natural ordering, no Comparator). The fields
 * (a, runBase, runLen, stackSize, tmpBase, minGallop), the constructor and the
 * helpers ensureCapacity, reverseRange and mergeForceCollapse are not shown here.
 */
class ComparableTimSort {
    /**
     * Sorts a[lo, hi). work/workBase/workLen describe an optional caller-supplied
     * temp array and are passed straight to the constructor.
     */
    static void sort(Object[] a, int lo, int hi, Object[] work, int workBase, int workLen) {
        assert a != null && lo >= 0 && lo <= hi && hi <= a.length;

        int nRemaining  = hi - lo;
        if (nRemaining < 2)
            return;  // Arrays of size 0 and 1 are always sorted

        // If array is small, do a "mini-TimSort" with no merges:
        // fewer than MIN_MERGE (32) elements -> binary insertion sort, seeded with
        // the natural run found at lo.
        if (nRemaining < MIN_MERGE) {
            int initRunLen = countRunAndMakeAscending(a, lo, hi);
            binarySort(a, lo, hi, lo + initRunLen);
            return;
        }

        /**
         * March over the array once, left to right, finding natural runs,
         * extending short natural runs to minRun elements, and merging runs
         * to maintain stack invariant.
         */
        ComparableTimSort ts = new ComparableTimSort(a, work, workBase, workLen);
        // Minimum run length for this input size (see minRunLength).
        int minRun = minRunLength(nRemaining);
        do {
            // Identify next run: length of the natural run starting at lo
            // (a descending run is reversed in place first).
            int runLen = countRunAndMakeAscending(a, lo, hi);

            // If run is short, extend to min(minRun, nRemaining) by binary
            // insertion sort of the following elements.
            if (runLen < minRun) {
                int force = nRemaining <= minRun ? nRemaining : minRun;
                binarySort(a, lo, lo + force, lo + runLen);
                runLen = force;
            }

            // Push run onto pending-run stack, and maybe merge adjacent runs
            // until the stack invariants hold again.
            ts.pushRun(lo, runLen);
            ts.mergeCollapse();

            // Advance to find next run
            lo += runLen;
            nRemaining -= runLen;
        } while (nRemaining != 0);

        // Merge all remaining runs to complete sort
        assert lo == hi;
        ts.mergeForceCollapse();
        assert ts.stackSize == 1;
    }

    // Returns the length of the run beginning at lo: either the longest
    // non-descending prefix a[lo] <= a[lo+1] <= a[lo+2] <= ..., or the longest
    // strictly descending prefix a[lo] > a[lo+1] > a[lo+2] > ..., which is then
    // reversed in place so every run ends up ascending. Requiring descending runs
    // to be strict means the reversal never reorders equal elements (stability).
    private static int countRunAndMakeAscending(Object[] a, int lo, int hi) {
        assert lo < hi;
        int runHi = lo + 1;
        if (runHi == hi)
            return 1;

        // Find end of run, and reverse range if descending
        if (((Comparable) a[runHi++]).compareTo(a[lo]) < 0) { // Descending
            while (runHi < hi && ((Comparable) a[runHi]).compareTo(a[runHi - 1]) < 0)
                runHi++;
            reverseRange(a, lo, runHi);
        } else {                              // Ascending
            while (runHi < hi && ((Comparable) a[runHi]).compareTo(a[runHi - 1]) >= 0)
                runHi++;
        }

        return runHi - lo;
    }

    // Binary insertion sort of a[lo, hi), assuming a[lo, start) is already sorted
    // (lo <= start <= hi; start is the first element not yet known to be in order).
    // It needs O(n log n) comparisons but O(n^2) element moves in the worst case,
    // which makes it the method of choice for small ranges. Equal elements keep
    // their relative order.
    private static void binarySort(Object[] a, int lo, int hi, int start) {
        assert lo <= start && start <= hi;
        if (start == lo)
            start++;
        for ( ; start < hi; start++) {
            Comparable pivot = (Comparable) a[start];

            // Set left (and right) to the index where a[start] (pivot) belongs
            int left = lo;
            int right = start;
            assert left <= right;
            /*
             * Invariants:
             *   pivot >= all in [lo, left).
             *   pivot <  all in [right, start).
             */
            while (left < right) {
                // Unsigned shift: the midpoint (left + right) / 2, computed correctly
                // even if left + right overflows int.
                int mid = (left + right) >>> 1;
                // pivot < a[mid]: the insertion point is at or left of mid.
                if (pivot.compareTo(a[mid]) < 0)
                    right = mid;
                else
                    left = mid + 1;
            }
            // The search interval has collapsed onto the insertion point.
            assert left == right;

            /*
             * The invariants still hold: pivot >= all in [lo, left) and
             * pivot < all in [left, start), so pivot belongs at left.  Note
             * that if there are elements equal to pivot, left points to the
             * first slot after them -- that's why this sort is stable.
             * Slide elements over to make room for pivot.
             */
            int n = start - left;  // The number of elements to move
            // Switch is just an optimization for arraycopy in default case
            // (case 2 intentionally falls through to case 1).
            switch (n) {
                case 2:  a[left + 2] = a[left + 1];
                case 1:  a[left + 1] = a[left];
                         break;
                default: System.arraycopy(a, left, a, left + 1, n);
            }
            // a[left, start) has been shifted right by one; drop pivot into the gap.
            a[left] = pivot;
        }
    }

    // Ranges shorter than this are sorted by binary insertion sort without any
    // merging; it is also the upper bound of minRunLength for larger inputs.
    private static final int MIN_MERGE = 32;
    // TimSort exploits the fact that real data often contains runs that are
    // already in order.
    // minRunLength(n): for n < MIN_MERGE it returns n; otherwise it returns k with
    // MIN_MERGE/2 <= k <= MIN_MERGE (16..32) such that n / k is equal to, or
    // slightly less than, a power of two -- merging is most efficient when the
    // number of runs is a power of two or slightly below one.
    // Concretely: keep the 5 most significant bits of n, and add 1 if any of the
    // remaining lower bits is set.
    //   189 = 0b10111101:   top bits 10111 (23), lower bits nonzero -> 24; 189 / 24 is about 7.9
    //   976 = 0b1111010000: top bits 11110 (30), lower bits nonzero -> 31; 976 / 31 is about 31.5
    // A minimum run length keeps run sizes balanced, so a very short run is not
    // merged with a very long one.
    private static int minRunLength(int n) {
        assert n >= 0;
        int r = 0;      // Becomes 1 if any 1 bits are shifted off
        // Halve n until it drops below MIN_MERGE. r ORs together every bit shifted
        // out, so it is 1 exactly when some dropped low bit was set (not just the last).
        while (n >= MIN_MERGE) {
            r |= (n & 1);
            n >>= 1;
        }
        return n + r;
    }

    // Pushes a run (start index, length) onto the pending-run stack.
    private void pushRun(int runBase, int runLen) {
        this.runBase[stackSize] = runBase;
        this.runLen[stackSize] = runLen;
        stackSize++;
    }

    // Merges adjacent runs until, for the three most recent runs X, Y, Z
    // (Z on top of the stack), both X > Y + Z and Y > Z hold.
    // 1. Only adjacent runs are ever merged (this keeps the sort stable).
    // 2. If X <= Y + Z, Y is merged with the shorter of X and Z (with Z on ties).
    // 3. Otherwise, if Y <= Z, Y and Z are merged.
    // NOTE(review): only the top three runs are examined, so the invariant can
    // still end up violated deeper in the stack. Python's merge_collapse (later in
    // these notes) also checks the fourth run from the top; later JDKs reportedly
    // adopted the same check (JDK-8203864) -- confirm against the JDK in use.
    private void mergeCollapse() {
        while (stackSize > 1) {
            int n = stackSize - 2;
            // runLen[n-1], runLen[n], runLen[n+1] are X, Y, Z.
            if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
                if (runLen[n - 1] < runLen[n + 1])
                    n--;
                mergeAt(n);
            } else if (runLen[n] <= runLen[n + 1]) {
                mergeAt(n);
            } else {
                break; // Invariant is established
            }
        }
    }

    // Merges the adjacent runs at stack positions i and i+1; i must be the
    // second- or third-from-top entry.
    private void mergeAt(int i) {
        assert stackSize >= 2;
        assert i >= 0;
        assert i == stackSize - 2 || i == stackSize - 3;

        int base1 = runBase[i];
        int len1 = runLen[i];
        int base2 = runBase[i + 1];
        int len2 = runLen[i + 1];
        assert len1 > 0 && len2 > 0;
        assert base1 + len1 == base2;

        /*
         * Record the length of the combined runs; if i is the 3rd-last
         * run now, also slide over the last run (which isn't involved
         * in this merge).  The current run (i+1) goes away in any case.
         */
        runLen[i] = len1 + len2; // combined length
        if (i == stackSize - 3) {
            runBase[i + 1] = runBase[i + 2];
            runLen[i + 1] = runLen[i + 2];
        }
        stackSize--;

        /*
         * Find where the first element of run2 goes in run1. Prior elements
         * in run1 can be ignored (because they're already in place).
         */
        int k = gallopRight((Comparable<Object>) a[base2], a, base1, len1, 0);
        assert k >= 0;
        base1 += k;
        len1 -= k; // part of run1 that still needs merging
        if (len1 == 0)
            return;

        /*
         * Find where the last element of run1 goes in run2. Subsequent elements
         * in run2 can be ignored (because they're already in place).
         */
        len2 = gallopLeft((Comparable<Object>) a[base1 + len1 - 1], a,
                base2, len2, len2 - 1);
        assert len2 >= 0; // part of run2 that still needs merging
        if (len2 == 0)
            return;

        // Merge remaining runs, using tmp array with min(len1, len2) elements:
        // whichever run is shorter is the one copied to tmp.
        if (len1 <= len2)
            mergeLo(base1, len1, base2, len2);
        else
            mergeHi(base1, len1, base2, len2);
    }

    // Returns the LEFTMOST position at which key can be inserted into the sorted
    // range a[base, base + len): a[base + result - 1] < key <= a[base + result].
    // The search gallops outward from base + hint. mergeAt uses it to find where
    // run1's last element belongs in run2; run2 elements from that point on are
    // already in their final place.
    private static int gallopLeft(Comparable<Object> key, Object[] a,
            int base, int len, int hint) {
        assert len > 0 && hint >= 0 && hint < len;

        int lastOfs = 0;
        int ofs = 1;
        if (key.compareTo(a[base + hint]) > 0) {
            // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
            int maxOfs = len - hint;
            while (ofs < maxOfs && key.compareTo(a[base + hint + ofs]) > 0) {
                lastOfs = ofs;
                ofs = (ofs << 1) + 1;
                if (ofs <= 0)   // int overflow
                    ofs = maxOfs;
            }
            if (ofs > maxOfs)
                ofs = maxOfs;

            // Make offsets relative to base
            lastOfs += hint;
            ofs += hint;
        } else { // key <= a[base + hint]
            // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
            final int maxOfs = hint + 1;
            while (ofs < maxOfs && key.compareTo(a[base + hint - ofs]) <= 0) {
                lastOfs = ofs;
                ofs = (ofs << 1) + 1;
                if (ofs <= 0)   // int overflow
                    ofs = maxOfs;
            }
            if (ofs > maxOfs)
                ofs = maxOfs;

            // Make offsets relative to base
            int tmp = lastOfs;
            lastOfs = hint - ofs;
            ofs = hint - tmp;
        }
        assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;

        /*
         * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere
         * to the right of lastOfs but no farther right than ofs.  Do a binary
         * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs].
         */
        lastOfs++;
        while (lastOfs < ofs) {
            int m = lastOfs + ((ofs - lastOfs) >>> 1);

            if (key.compareTo(a[base + m]) > 0)
                lastOfs = m + 1;  // a[base + m] < key
            else
                ofs = m;          // key <= a[base + m]
        }
        assert lastOfs == ofs;    // so a[base + ofs - 1] < key <= a[base + ofs]
        return ofs;
    }

    // Galloping search: rather than stepping one element at a time, key is
    // compared with the elements at offsets 1, 3, 7, 15, ... (2^k - 1) from the
    // hint until its position is bracketed, and a binary search then finishes
    // inside that bracket, which narrows the binary search range. mergeLo/mergeHi
    // leave galloping mode once gallops stop moving at least MIN_GALLOP elements.
    // gallopRight returns the RIGHTMOST insertion position for key in the sorted
    // range a[base, base + len): a[base + result - 1] <= key < a[base + result].
    // mergeAt uses it to find where run2's first element belongs in run1; run1
    // elements before that point are already in their final place.
    private static int gallopRight(Comparable<Object> key, Object[] a,
            int base, int len, int hint) {
        assert len > 0 && hint >= 0 && hint < len;

        int ofs = 1;
        int lastOfs = 0;
        // key < a[base + hint]: gallop left from the hint.
        if (key.compareTo(a[base + hint]) < 0) {
            // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
            int maxOfs = hint + 1;
            while (ofs < maxOfs && key.compareTo(a[base + hint - ofs]) < 0) {
                lastOfs = ofs;
                ofs = (ofs << 1) + 1; // ofs = 2 * ofs + 1
                if (ofs <= 0)   // int overflow
                    ofs = maxOfs;
            }
            if (ofs > maxOfs)
                ofs = maxOfs;

            // Make offsets relative to b
            int tmp = lastOfs;
            lastOfs = hint - ofs;
            ofs = hint - tmp;
        } else { // a[b + hint] <= key: gallop right from the hint
            // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
            int maxOfs = len - hint;
            while (ofs < maxOfs && key.compareTo(a[base + hint + ofs]) >= 0) {
                lastOfs = ofs;
                ofs = (ofs << 1) + 1;
                if (ofs <= 0)   // int overflow
                    ofs = maxOfs;
            }
            if (ofs > maxOfs)
                ofs = maxOfs;

            // Make offsets relative to b
            lastOfs += hint;
            ofs += hint;
        }
        assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;

        /*
         * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to
         * the right of lastOfs but no farther right than ofs.  Do a binary
         * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs].
         */
        lastOfs++;
        while (lastOfs < ofs) {
            int m = lastOfs + ((ofs - lastOfs) >>> 1);

            if (key.compareTo(a[base + m]) < 0)
                ofs = m;          // key < a[b + m]
            else
                lastOfs = m + 1;  // a[b + m] <= key
        }
        assert lastOfs == ofs;    // so a[b + ofs - 1] <= key < a[b + ofs]
        return ofs;
    }

    // Galloping mode continues only while a gallop moves at least MIN_GALLOP
    // elements (see the do/while conditions in mergeLo and mergeHi).
    private static final int  MIN_GALLOP = 7;

    // Merges two adjacent runs in place, stably, when len1 <= len2 (see mergeAt):
    // run1 is copied to tmp and the merge proceeds left to right.
    // base1/len1: start and length of run1; base2/len2: start and length of run2.
    private void mergeLo(int base1, int len1, int base2, int len2) {
        assert len1 > 0 && len2 > 0 && base1 + len1 == base2;

        // Copy first run into temp array
        Object[] a = this.a; // For performance
        Object[] tmp = ensureCapacity(len1);

        int cursor1 = tmpBase; // Indexes into tmp array
        int cursor2 = base2;   // Indexes into a
        int dest = base1;      // Indexes into a
        System.arraycopy(a, base1, tmp, cursor1, len1);

        // Move first element of second run and deal with degenerate cases.
        // run2's first element is smaller than every remaining run1 element
        // (mergeAt trimmed run1 with gallopRight), so it goes first.
        a[dest++] = a[cursor2++];
        if (--len2 == 0) {
            System.arraycopy(tmp, cursor1, a, dest, len1);
            return;
        }
        if (len1 == 1) {
            System.arraycopy(a, cursor2, a, dest, len2);
            a[dest + len2] = tmp[cursor1]; // Last elt of run 1 to end of merge
            return;
        }

        int minGallop = this.minGallop;  // Use local variable for performance
    outer:
        while (true) {
            int count1 = 0; // Number of times in a row that first run won
            int count2 = 0; // Number of times in a row that second run won

            /*
             * Do the straightforward thing until (if ever) one run starts
             * winning consistently.
             */
            do {
                assert len1 > 1 && len2 > 0;
                if (((Comparable) a[cursor2]).compareTo(tmp[cursor1]) < 0) {
                    a[dest++] = a[cursor2++];
                    count2++;
                    count1 = 0;
                    if (--len2 == 0)
                        break outer;
                } else {
                    a[dest++] = tmp[cursor1++];
                    count1++;
                    count2 = 0;
                    if (--len1 == 1)
                        break outer;
                }
            } while ((count1 | count2) < minGallop);

            /*
             * One run is winning so consistently that galloping may be a
             * huge win. So try that, and continue galloping until (if ever)
             * neither run appears to be winning consistently anymore.
             */
            do {
                assert len1 > 1 && len2 > 0;
                count1 = gallopRight((Comparable) a[cursor2], tmp, cursor1, len1, 0);
                if (count1 != 0) {
                    System.arraycopy(tmp, cursor1, a, dest, count1);
                    dest += count1;
                    cursor1 += count1;
                    len1 -= count1;
                    if (len1 <= 1)  // len1 == 1 || len1 == 0
                        break outer;
                }
                a[dest++] = a[cursor2++];
                if (--len2 == 0)
                    break outer;

                count2 = gallopLeft((Comparable) tmp[cursor1], a, cursor2, len2, 0);
                if (count2 != 0) {
                    System.arraycopy(a, cursor2, a, dest, count2);
                    dest += count2;
                    cursor2 += count2;
                    len2 -= count2;
                    if (len2 == 0)
                        break outer;
                }
                a[dest++] = tmp[cursor1++];
                if (--len1 == 1)
                    break outer;
                // Staying in galloping mode lowers the threshold for re-entering it.
                minGallop--;
            } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
            if (minGallop < 0)
                minGallop = 0;
            minGallop += 2;  // Penalize for leaving gallop mode
        }  // End of "outer" loop
        this.minGallop = minGallop < 1 ? 1 : minGallop;  // Write back to field

        if (len1 == 1) {
            assert len2 > 0;
            System.arraycopy(a, cursor2, a, dest, len2);
            a[dest + len2] = tmp[cursor1]; //  Last elt of run 1 to end of merge
        } else if (len1 == 0) {
            // Reachable only when compareTo is inconsistent.
            throw new IllegalArgumentException(
                "Comparison method violates its general contract!");
        } else {
            assert len2 == 0;
            assert len1 > 1;
            System.arraycopy(tmp, cursor1, a, dest, len1);
        }
    }

    // Merges two adjacent runs in place, stably, when len1 > len2 (see mergeAt):
    // run2 is copied to tmp and the merge proceeds right to left.
    // base1/len1: start and length of run1; base2/len2: start and length of run2.
    private void mergeHi(int base1, int len1, int base2, int len2) {
        assert len1 > 0 && len2 > 0 && base1 + len1 == base2;

        // Copy second run into temp array
        Object[] a = this.a; // For performance
        Object[] tmp = ensureCapacity(len2);
        int tmpBase = this.tmpBase;
        // tmp[tmpBase, tmpBase + len2) now holds run2.
        System.arraycopy(a, base2, tmp, tmpBase, len2);

        // Cursor into run1, starting at its last element.
        int cursor1 = base1 + len1 - 1;  // Indexes into a
        // Cursor into the tmp copy of run2, starting at its last element.
        int cursor2 = tmpBase + len2 - 1; // Indexes into tmp array
        // Destination cursor: last slot of the merged range (end of run2).
        int dest = base2 + len2 - 1;     // Indexes into a

        // Move last element of first run and deal with degenerate cases.
        // run1's last element is greater than every remaining run2 element
        // (mergeAt trimmed run2 with gallopLeft), so it goes to the very end.
        a[dest--] = a[cursor1--];
        // run1 had a single element: copy all of run2 in front of it.
        if (--len1 == 0) {
            System.arraycopy(tmp, tmpBase, a, dest - (len2 - 1), len2);
            return;
        }
        // run2 has a single element: shift the rest of run1 right and place it in front.
        if (len2 == 1) {
            dest -= len1;
            cursor1 -= len1;
            System.arraycopy(a, cursor1 + 1, a, dest + 1, len1);
            a[dest] = tmp[cursor2];
            return;
        }

        // Current galloping threshold (adapted across merges; presumably starts at
        // MIN_GALLOP = 7 -- the constructor is not shown).
        int minGallop = this.minGallop;  // Use local variable for performance
    outer:
        while (true) {
            int count1 = 0; // Number of times in a row that first run won
            int count2 = 0; // Number of times in a row that second run won

            /*
             * Do the straightforward thing until (if ever) one run
             * appears to win consistently.
             */
            do {
                assert len1 > 0 && len2 > 1;
                // Compare the current last elements of run2 (tmp) and run1.
                if (((Comparable) tmp[cursor2]).compareTo(a[cursor1]) < 0) {
                    a[dest--] = a[cursor1--];
                    count1++;
                    count2 = 0;
                    if (--len1 == 0)
                        break outer;
                } else {
                    a[dest--] = tmp[cursor2--];
                    count2++;
                    count1 = 0;
                    if (--len2 == 1)
                        break outer;
                }
            } while ((count1 | count2) < minGallop);

            /*
             * One run is winning so consistently that galloping may be a
             * huge win. So try that, and continue galloping until (if ever)
             * neither run appears to be winning consistently anymore.
             */
            do {
                assert len1 > 0 && len2 > 1;
                count1 = len1 - gallopRight((Comparable) tmp[cursor2], a, base1, len1, len1 - 1);
                if (count1 != 0) {
                    dest -= count1;
                    cursor1 -= count1;
                    len1 -= count1;
                    System.arraycopy(a, cursor1 + 1, a, dest + 1, count1);
                    if (len1 == 0)
                        break outer;
                }
                // After each galloped block, move one element from the other run, then
                // gallop again (gallopRight/gallopLeft) to skip the next block of
                // elements that need no element-by-element merging.
                a[dest--] = tmp[cursor2--];
                if (--len2 == 1)
                    break outer;

                count2 = len2 - gallopLeft((Comparable) a[cursor1], tmp, tmpBase, len2, len2 - 1);
                if (count2 != 0) {
                    dest -= count2;
                    cursor2 -= count2;
                    len2 -= count2;
                    System.arraycopy(tmp, cursor2 + 1, a, dest + 1, count2);
                    if (len2 <= 1)
                        break outer; // len2 == 1 || len2 == 0
                }
                a[dest--] = a[cursor1--];
                if (--len1 == 0)
                    break outer;
                // Staying in galloping mode lowers the threshold for re-entering it.
                minGallop--;
            } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
            if (minGallop < 0)
                minGallop = 0;
            minGallop += 2;  // Penalize for leaving gallop mode
        }  // End of "outer" loop
        this.minGallop = minGallop < 1 ? 1 : minGallop;  // Write back to field

        if (len2 == 1) {
            assert len1 > 0;
            dest -= len1;
            cursor1 -= len1;
            System.arraycopy(a, cursor1 + 1, a, dest + 1, len1);
            a[dest] = tmp[cursor2];  // Move first elt of run2 to front of merge
        } else if (len2 == 0) {
            // Reachable only when compareTo is inconsistent.
            throw new IllegalArgumentException(
                "Comparison method violates its general contract!");
        } else {
            assert len1 == 0;
            assert len2 > 0;
            System.arraycopy(tmp, tmpBase, a, dest - (len2 - 1), len2);
        }
    }
}

在普通的(自底向上)归并排序中,合并是两两进行的:第一个和第二个合并,第三个和第四个合并,然后再合并这两个已经合并的序列。而在 Timsort 中,合并是随着扫描连续进行的:每识别出一个 run 并压栈后,都可能触发一次合并。由于每次只合并栈中相邻的 run,这样的合并顺序能够在合并的同时保证算法的稳定性。

Timsort 称这些已经排好序的数据块为 run。在排序时,Timsort迭代数据元素,将其放到不同的 run 里,同时针对这些 run ,按规则进行合并至只剩一个,则这个仅剩的 run 即为排好序的结果。

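合并的平衡性是为了让参与合并的两个 run 的大小尽量接近,以提高合并效率。一方面,合并时需要尽量把 run 保留在栈上,以便发现后续数据中的模式;另一方面,又希望尽快合并刚生成、仍位于较高内存层级(缓存)中的 run;而且栈空间有限,不能浪费太多。在这两个约束下,Timsort 维护栈不变式 $len[i] > len[i+1] + len[i+2]$ 与 $len[i+1] > len[i+2]$(下标越大越靠近栈顶),使栈中从底部到顶部的 run 长度严格递减,并且自顶向下至少以斐波那契数列的速度增长。这样就可以用斐波那契数列求和的公式,根据数组长度 $N$ 计算出栈所需的最大深度:约为 $\log_{\varphi} N$ 再加一个小常数,其中 $\varphi \approx 1.618$ 为黄金分割比。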

在最理想的情况下,这个栈从底部到顶部的数字应该是128、64、32、16、8、4、2、2,这样从栈顶合并到栈底,每次合并的两个run的长度都是相等的,都是完美的合并。

Timsort 合并 2 个相邻的 run 时需要临时存储空间,其大小为 2 个 run 中较小者的长度。Timsort 先将较小的 run 复制到这个临时存储空间,然后用原先存放这 2 个 run 的空间来存储合并后的 run。

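合并之前,先用 gallopRight 和 gallopLeft 裁掉已经就位的部分:gallopRight 找出 run2 第一个元素在 run1 中的位置,gallopLeft 找出 run1 最后一个元素在 run2 中的位置,只有这两个位置之间的元素才需要真正合并(见 mergeAt)。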

加速合并

在归并排序算法中合并两个数组就是一一比较每个元素,把较小的放到相应的位置,然后比较下一个,这样有一个缺点就是如果A中如果有大量的元素A[i…j]是小于B中某一个元素B[k]的,程序仍然会持续的比较A[i…j]中的每一个元素和B[k],增加合并过程中的时间消耗。

为了优化合并过程,Tim 设定了一个阈值 MIN_GALLOP:如果 A 中连续 MIN_GALLOP 个元素都比 B 中某一个元素小,就进入 GALLOP 模式,反之亦然。默认的 MIN_GALLOP 值是 7。实际使用的阈值 minGallop 会在合并过程中动态调整:在 GALLOP 模式中每多停留一轮就减 1,退出 GALLOP 模式时加 2 作为惩罚。

在 GALLOP 模式中,首先通过 gallop 搜索找到 A[0] 在 B 中的位置 i0:依次与偏移 1、3、7、15……(即 $2^k - 1$)处的元素比较,确定目标位置所在的区间,再在区间内二分查找。找到 i0 后,把 B 中 i0 之前的元素直接放入合并空间。然后再在 A 中找到 B[i0] 所在的位置 j0,把 A 中 j0 之前的元素直接放入合并空间。如此循环,直到每次找到的新位置与原位置的差值都小于 MIN_GALLOP,才退出 GALLOP 模式,恢复逐个比较。

Python

版本 3.10.2

内部实现机制为:Timsort、二分插入排序(binarysort)

Java 的 TimSort 移植自 Python 的实现,原理相同。主要区别之一是 minrun 的取值范围:Python 为 32~64(见 merge_compute_minrun),Java 为 16~32(MIN_MERGE = 32)。

Objects/listobject.c

/* C API entry point: sort the list v in place (no key function, not reversed).
 * Returns 0 on success, -1 on failure; a NULL or non-list argument is reported
 * through PyErr_BadInternalCall(). */
int
PyList_Sort(PyObject *v)
{
    PyObject *res;

    if (v != NULL && PyList_Check(v)) {
        res = list_sort_impl((PyListObject *)v, NULL, 0);
        if (res != NULL) {
            Py_DECREF(res);
            return 0;
        }
        return -1;
    }
    PyErr_BadInternalCall();
    return -1;
}

/* Core of list.sort(): stable in-place timsort of self's items.
 * When keys != NULL (presumably computed from keyfunc in the elided code --
 * confirm), runs are detected and merged on the keys. reverse is implemented
 * by reversing before and after a stable forward sort, which keeps equal
 * elements in their original order. The return-type line is omitted in this
 * excerpt; PyList_Sort treats the result as a PyObject * (Py_None on success,
 * NULL on failure). */
list_sort_impl(PyListObject *self, PyObject *keyfunc, int reverse)
/*[clinic end generated code: output=57b9f9c5e23fbe42 input=cb56cd179a713060]*/
{
    MergeState ms;
    Py_ssize_t nremaining;
    Py_ssize_t minrun;
    sortslice lo;
    Py_ssize_t saved_ob_size, saved_allocated;
    PyObject **saved_ob_item;
    PyObject **final_ob_item;
    PyObject *result = NULL;            /* guilty until proved innocent */
    Py_ssize_t i;
    PyObject **keys;

    ...


    /* The pre-sort check: here's where we decide which compare function to use.
     * How much optimization is safe? We test for homogeneity with respect to
     * several properties that are expensive to check at compare-time, and
     * set ms appropriately. */
    ...

    merge_init(&ms, saved_ob_size, keys != NULL);

    nremaining = saved_ob_size;
    if (nremaining < 2)
        goto succeed;

    /* Reverse sort stability achieved by initially reversing the list,
    applying a stable forward sort, then reversing the final result. */
    if (reverse) {
        if (keys != NULL)
            reverse_slice(&keys[0], &keys[saved_ob_size]);
        reverse_slice(&saved_ob_item[0], &saved_ob_item[saved_ob_size]);
    }

    /* March over the array once, left to right, finding natural runs,
     * and extending short natural runs to minrun elements.
     */
    // Minimum run length: n itself when n < 64, otherwise 32 <= minrun <= 64.
    // (lo, the current slice, is presumably initialized in the elided code above.)
    minrun = merge_compute_minrun(nremaining);
    do {
        int descending;
        Py_ssize_t n;

        /* Identify next run. */
        n = count_run(&ms, lo.keys, lo.keys + nremaining, &descending);
        if (n < 0)
            goto fail;
        if (descending)
            reverse_sortslice(&lo, n);
        /* If short, extend to min(minrun, nremaining). */
        // The natural run is shorter than minrun: extend it with binary insertion sort.
        if (n < minrun) {
            const Py_ssize_t force = nremaining <= minrun ?
                              nremaining : minrun;
            if (binarysort(&ms, lo, lo.keys + force, lo.keys + n) < 0)
                goto fail;
            n = force;
        }
        /* Push run onto pending-runs stack, and maybe merge. */
        assert(ms.n < MAX_MERGE_PENDING);
        // Record the new run (base, len) on the pending stack.
        ms.pending[ms.n].base = lo;
        ms.pending[ms.n].len = n;
        ++ms.n;
        if (merge_collapse(&ms) < 0)
            goto fail;
        /* Advance to find next run. */
        sortslice_advance(&lo, n);
        nremaining -= n;
    } while (nremaining);

    if (merge_force_collapse(&ms) < 0)
        goto fail;
    assert(ms.n == 1);
    assert(keys == NULL
           ? ms.pending[0].base.keys == saved_ob_item
           : ms.pending[0].base.keys == &keys[0]);
    assert(ms.pending[0].len == saved_ob_size);
    lo = ms.pending[0].base;

succeed:
    result = Py_None;
fail:
    ...
    /* NOTE(review): as excerpted, this unconditionally clears result (so success
     * would also yield NULL), and the final return statement is missing; the elided
     * code presumably makes this conditional -- check Objects/listobject.c. */
    result = NULL;
}
/* Drop the IFLT/ISLT helper macros; their definitions are not shown in this excerpt. */
#undef IFLT
#undef ISLT

/* Choose the minimum run length ("minrun") for a list of n elements.
 * Natural runs shorter than this are extended with binary insertion
 * before they are pushed on the pending-run stack.
 *
 *   - n < 64: returns n itself, so the whole list is handled as one run.
 *   - n >= 64 and an exact power of 2: returns 32.
 *   - otherwise: returns some k with 32 <= k <= 64 such that n/k is close
 *     to, but strictly less than, an exact power of 2.
 *
 * Implementation: keep the 6 most significant bits of n and add 1 if any
 * of the bits shifted away was set.
 *
 * See listsort.txt for the rationale.
 */
static Py_ssize_t
merge_compute_minrun(Py_ssize_t n)
{
    Py_ssize_t lost_bits = 0;   /* 1 once any set bit has been shifted out */

    assert(n >= 0);
    for (; n >= 64; n >>= 1) {
        lost_bits |= n & 1;
    }
    return n + lost_bits;
}

/* Examine the stack of runs waiting to be merged, merging adjacent runs
 * until the stack invariants are re-established:
 *
 * 1. len[-3] > len[-2] + len[-1]
 * 2. len[-2] > len[-1]
 *
 * Invariant 1 is also re-checked one entry deeper (len[-4] against
 * len[-3] + len[-2]).  Checking only the top three entries is not enough
 * to keep it true for the whole stack; this was shown in the 2015 TimSort
 * analysis by de Gouw et al.
 *
 * See listsort.txt for more info.
 *
 * Returns 0 on success, -1 on error.
 */
static int
merge_collapse(MergeState *ms)
{
    struct s_slice *p;

    assert(ms != NULL);         /* must run before the first dereference */
    p = ms->pending;

    while (ms->n > 1) {
        /* n indexes the second-from-top run; p[n+1] is the top run. */
        Py_ssize_t n = ms->n - 2;
        if ((n > 0 && p[n-1].len <= p[n].len + p[n+1].len) ||
            (n > 1 && p[n-2].len <= p[n-1].len + p[n].len)) {
            /* Invariant 1 violated: merge p[n] with its shorter neighbour,
             * p[n-1] when it is strictly shorter than p[n+1], else p[n+1]. */
            if (p[n-1].len < p[n+1].len)
                --n;
            if (merge_at(ms, n) < 0)
                return -1;
        }
        else if (p[n].len <= p[n+1].len) {
            /* Invariant 2 violated: merge the top two runs. */
            if (merge_at(ms, n) < 0)
                return -1;
        }
        else
            break;              /* both invariants hold */
    }
    return 0;
}

/* Merge the two adjacent pending runs at stack indices i and i+1 into a
 * single run stored at index i.
 *
 * i must be the second-to-last (ms->n - 2) or the third-to-last
 * (ms->n - 3) stack entry.  Only in those positions does sliding the top
 * run down by one slot (below) keep the pending stack contiguous.
 *
 * Returns 0 on success and a negative value on error.  Error results come
 * from gallop_right / gallop_left, or are passed through unchanged from
 * merge_lo / merge_hi (assumed to use the same 0 / -1 convention; they
 * are not shown here).
 */
static Py_ssize_t
merge_at(MergeState *ms, Py_ssize_t i)
{
    sortslice ssa, ssb;
    Py_ssize_t na, nb;
    Py_ssize_t k;

    assert(ms != NULL);
    assert(ms->n >= 2);
    assert(i >= 0);
    assert(i == ms->n - 2 || i == ms->n - 3);

    ssa = ms->pending[i].base;
    na = ms->pending[i].len;
    ssb = ms->pending[i+1].base;
    nb = ms->pending[i+1].len;
    assert(na > 0 && nb > 0);
    assert(ssa.keys + na == ssb.keys);      /* runs are adjacent in memory */

    /* Record the length of the combined runs; if i is the 3rd-last
     * run now, also slide over the last run (which isn't involved
     * in this merge).  The current run i+1 goes away in any case.
     */
    ms->pending[i].len = na + nb;
    if (i == ms->n - 3)
        ms->pending[i+1] = ms->pending[i+2];
    --ms->n;

    /* Where does b start in a?  Elements in a before that can be
     * ignored (already in place).
     */
    k = gallop_right(ms, *ssb.keys, ssa.keys, na, 0);
    if (k < 0)
        return -1;
    sortslice_advance(&ssa, k);
    na -= k;
    if (na == 0)
        return 0;

    /* Where does a end in b?  Elements in b after that can be
     * ignored (already in place).
     */
    nb = gallop_left(ms, ssa.keys[na-1], ssb.keys, nb, nb-1);
    if (nb <= 0)
        return nb;

    /* Merge what remains of the runs, using a temp array with
     * min(na, nb) elements.
     */
    if (na <= nb)
        return merge_lo(ms, ssa, na, ssb, nb);
    else
        return merge_hi(ms, ssa, na, ssb, nb);
}

Golang

sort/sort.go(以下代码为 Go 1.18 及更早版本的实现;自 Go 1.19 起,sort.Sort 改用 pdqsort)

内部实现机制为:快速排序、堆排序、希尔排序、插入排序

算法稳定性:不稳定

  • 元素个数 <= 12(且多于 1 个)时,先做一趟 gap = 6 的希尔排序,再用插入排序收尾
  • 元素个数 > 12 时,使用快速排序;当递归深度预算 maxDepth 耗尽(减到 0)时改用堆排序
Len() int:返回集合中元素的总数
Less(i, j int) bool:返回下标为 i 的元素是否应排在下标为 j 的元素之前(即是否"小于")
Swap(i, j int):交换下标为 i 和下标为 j 的两个元素

调用sort包的排序,需要实现排序接口

// Interface is the set of operations a collection must provide so that
// Sort can order it. Elements are addressed by integer index, and Sort
// works on indexes in the range [0, Len()).
type Interface interface {
	// Len returns how many elements the collection holds.
	Len() int
	// Less reports whether the element at index i has to be placed
	// before the element at index j.
	Less(i, j int) bool
	// Swap exchanges the elements stored at indexes i and j.
	Swap(i, j int)
}

Sort方法

// Sort orders data ascending according to data.Less. It calls data.Len
// exactly once. The recursion budget computed by maxDepth lets quickSort
// fall back to heapsort, which bounds the worst case. The result is not
// guaranteed to be stable.
func Sort(data Interface) {
	length := data.Len()
	depthLimit := maxDepth(length)
	quickSort(data, 0, length, depthLimit)
}

// quickSort sorts data[a:b] in place using an introsort-style scheme:
//   - ranges longer than 12 are partitioned by doPivot;
//   - once the maxDepth budget reaches 0, the rest is handed to heapSort;
//   - ranges of at most 12 elements get one gap-6 Shell pass followed by
//     insertion sort.
func quickSort(data Interface, a, b, maxDepth int) {
	for b-a > 12 {
		if maxDepth == 0 {
			// Depth budget exhausted: switch to heapSort so the worst case
			// stays O(n log n) instead of degrading to quadratic.
			heapSort(data, a, b)
			return
		}
		maxDepth--

		pivotLo, pivotHi := doPivot(data, a, b)
		// Recurse only into the smaller side and keep looping on the
		// larger one; this keeps the stack depth at most lg(b-a).
		if pivotLo-a < b-pivotHi {
			quickSort(data, a, pivotLo, maxDepth)
			a = pivotHi
		} else {
			quickSort(data, pivotHi, b, maxDepth)
			b = pivotLo
		}
	}

	if b-a <= 1 {
		return // zero or one element: nothing to do
	}
	// With b-a <= 12 the indexes i in [a+6, b) and their partners i-6 never
	// overlap, so this single compare-and-swap sweep is a full gap-6 pass.
	for i := a + 6; i < b; i++ {
		if data.Less(i, i-6) {
			data.Swap(i, i-6)
		}
	}
	insertionSort(data, a, b)
}

// maxDepth returns the number of partitioning rounds quickSort may perform
// before switching to heapsort: 2*ceil(lg(n+1)), i.e. twice the bit length
// of n. For example maxDepth(20) == 10. It returns 0 for n <= 0.
func maxDepth(n int) int {
	bitLen := 0
	for n > 0 {
		n >>= 1
		bitLen++
	}
	return 2 * bitLen
}

// heapSort sorts data[a:b] in place with a max-heap. It first builds the
// heap over the whole range. Then it repeatedly swaps the root (the current
// maximum) to the end of the shrinking heap and sifts the new root down.
// Heap indexes are relative to a; element k of the heap lives at data[a+k].
func heapSort(data Interface, a, b int) {
	base := a
	size := b - a

	// Heapify: sift down every internal node, last parent first.
	for node := (size - 1) / 2; node >= 0; node-- {
		siftDown(data, node, size, base)
	}

	// Move the maximum to position end, then repair the heap on [0, end).
	for end := size - 1; end >= 0; end-- {
		data.Swap(base, base+end)
		siftDown(data, 0, end, base)
	}
}

// siftDown restores the max-heap property for the subtree rooted at heap
// index lo, in a heap that occupies heap indexes [0, hi). Heap index k maps
// to data[first+k]. The parent is moved down while it is smaller than its
// larger child.
func siftDown(data Interface, lo, hi, first int) {
	parent := lo
	for child := 2*parent + 1; child < hi; child = 2*parent + 1 {
		// Prefer the right child when it exists and is the larger one.
		if right := child + 1; right < hi && data.Less(first+child, first+right) {
			child = right
		}
		// Parent is not smaller than its larger child: heap is valid here.
		if !data.Less(first+parent, first+child) {
			return
		}
		data.Swap(first+parent, first+child)
		parent = child
	}
}

// doPivot partitions data[lo:hi] around a pivot value and returns
// (midlo, midhi). quickSort calls it only when hi-lo > 12. The caller then
// recurses only on data[lo:midlo] and data[midhi:hi], so data[midlo:midhi]
// must already be in its final place. That range holds the pivot and, when
// the duplicate guard below triggers, the elements found equal to it.
//
// Pivot choice: a median of the samples at lo, m and hi-1. For ranges
// longer than 40, each sample is first refined with a local median of three
// (Tukey's ninther). NOTE(review): this assumes medianOfThree(data, x, y, z)
// leaves the median of the three elements in data[x]; medianOfThree is not
// shown here.
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
	// Midpoint via an unsigned shift: even if lo+hi wraps around as an int,
	// reinterpreting it as uint before halving gives the correct midpoint.
	m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
	if hi-lo > 40 {
		// Large range: replace each of the three samples by a local median
		// (sample spacing s = (hi-lo)/8) before taking the final median.
		// Tukey's ``Ninther,'' median of three medians of three.
		s := (hi - lo) / 8
		medianOfThree(data, lo, lo+s, lo+2*s)
		medianOfThree(data, m, m-s, m+s)
		medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
	}
	medianOfThree(data, lo, m, hi-1)

	// Invariants are:
	//	data[lo] = pivot (left there by the final medianOfThree call above)
	//	data[lo < i < a] < pivot
	//	data[a <= i < b] <= pivot
	//	data[b <= i < c] unexamined
	//	data[c <= i < hi-1] > pivot
	//	data[hi-1] >= pivot
	pivot := lo
	a, c := lo+1, hi-1

	// Skip the leading elements that are already strictly less than pivot.
	for ; a < c && data.Less(a, pivot); a++ {
	}
	b := a
	for {
		for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
		}
		for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
		}
		if b >= c {
			break
		}
		// data[b] > pivot; data[c-1] <= pivot
		data.Swap(b, c-1)
		b++
		c--
	}
	// If hi-c<3 then there are duplicates (by property of median of nine).
	// Let's be a bit more conservative, and set border to 5.
	protect := hi-c < 5
	if !protect && hi-c < (hi-lo)/4 {
		// Lets test some points for equality to pivot
		dups := 0
		if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
			data.Swap(c, hi-1)
			c++
			dups++
		}
		if !data.Less(b-1, pivot) { // data[b-1] = pivot
			b--
			dups++
		}
		// m-lo = (hi-lo)/2 > 6
		// b-lo > (hi-lo)*3/4-1 > 8
		// ==> m < b ==> data[m] <= pivot
		if !data.Less(m, pivot) { // data[m] = pivot
			data.Swap(m, b-1)
			b--
			dups++
		}
		// if at least 2 points are equal to pivot, assume skewed distribution
		protect = dups > 1
	}
	if protect {
		// Protect against a lot of duplicates
		// Add invariant:
		//	data[a <= i < b] unexamined
		//	data[b <= i < c] = pivot
		for {
			for ; a < b && !data.Less(b-1, pivot); b-- { // data[b-1] == pivot
			}
			for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
			}
			if a >= b {
				break
			}
			// data[a] == pivot; data[b-1] < pivot
			data.Swap(a, b-1)
			a++
			b--
		}
	}
	// Swap pivot into middle: data[b-1] becomes the pivot and
	// [b-1, c) is the block the caller will not recurse into.
	data.Swap(pivot, b-1)
	return b - 1, c
}

// insertionSort sorts data[a:b] in place by growing a sorted prefix one
// element at a time. Each new element is swapped leftward while it is
// strictly less than its left neighbour. It is stable, since equal
// elements never pass each other, and does O((b-a)^2) comparisons in the
// worst case; quickSort uses it only for short ranges.
func insertionSort(data Interface, a, b int) {
	for next := a + 1; next < b; next++ {
		for pos := next; pos > a; pos-- {
			if !data.Less(pos, pos-1) {
				break
			}
			data.Swap(pos, pos-1)
		}
	}
}

资料

《算法(第4版)》(Robert Sedgewick、Kevin Wayne 著)