玄学优化一个稳定排序算法

前一阵子(还挺前的)正好在忙数据结构的课程设计,大体是要求做一个航班管理系统。程序主体就是简单堆几个高效数据结构,再糊上一个RESTful API,没什么好谈的。不过在优化其中的排序算法时倒是学到了挺多。虽然说本质还是缝合若干优秀算法,但刚好最近也很久没更新博客了,所以干脆写一篇博客简述当时的思路吧。优化思路本身都是拾人牙慧,有错漏还请指出。

整体思路

首先要保证的是排序算法必须稳定。由于题目要求多条件排序,因此重复多次排序就要求排序算法的稳定性了。虽然多条件排序可以直接通过修改比较器的方法来实现,而Java也提供了Comparator::then来实现,不过这样就没数据结构内味了。而且非稳定的排序算法,IntroSort就已经很优秀了。

基本的优化思路还是以IntroSort为范本,对快速排序进行优化。快速排序本身效率很高,但是大致有如下问题:

  1. 在 right – left 较小时,递归调用较多,操作效率低
  2. 枢轴选择不当时,最差会退化至 O(n^2)
  3. 处理大量重复数据时,枢轴选择容易不当
  4. 递归操作本身效率低下

而按照IntroSort思路,可以优化1、2两点,即:

  1. 在 right – left 较小时,使用插入排序
  2. 在递归层数过深(快排退化)时,使用归并排序

其他两点可以在单独针对快速排序时优化。因此,优化的主体是三个排序算法:插入排序、归并排序、快速排序。

成对插入排序

插入排序本身就是稳定的,因此无需对插入排序进行稳定化处理。对插入排序的优化主要借鉴自Java标准库DualPivotQuicksort::sort。简而言之,就是对有序的两个元素同时插入,由此就可以让两趟遍历减少为一趟。比如对如下情况,选择的一对元素就是15和6。

先对较大的元素15进行插入之后,就可以从当前位置继续查找较小值6插入的位置。

插入6之后,完成一趟排序。

使用Java代码实现的关键部分大致如下(代码来自JDK):

// 检测开头是否有升序段
do {
    if (cur >= right)
        return;
} while (cmp.compare(arr[++cur], arr[cur - 1]) >= 0);

// 成对插入排序主逻辑
for (var i = cur; cur < right; i = ++cur) {
    cur++;
    S fst = arr[i], snd = arr[cur];
    if (cmp.compare(fst, snd) < 0) {
        snd = fst;
        fst = arr[cur];
    }
    // 此时 i, cur 分别指向待插入元素,并且 fst > snd
    // 先插入较大元素 fst
    while (--i >= left && cmp.compare(fst, arr[i]) < 0) {
        arr[i + 2] = arr[i]; // 偏移为 2 保证元素位置比 fst, snd 都大
    }
    arr[++i + 1] = fst;
    // 再插入较小元素 snd
    while (--i >= left && cmp.compare(snd, arr[i]) < 0) {
        arr[i + 1] = arr[i];
    }
    arr[i + 1] = snd;
}

// 处理多余元素
// ...

不难看出,由于将相邻两趟遍历减少为一次遍历,因此比较次数应该大致为原始选择排序的一半。以下为比较次数的对比测试,选取了最差情况倒序和随机两种情况进行比较。

完整的代码位于:net.kaaass.kflight.algorithm.sort.BiInsertSort::sort

自适应归并排序

这一部分借鉴自C++ STL的std::stable_sort。自适应听起来高大上,实际上就是用到了数据本身的有序部分。算法主要针对归并排序的两个点进行优化,一是归并排序在数据较小时频繁合并、且没有用到数据本身的有序,二是归并排序的递归调用。算法的大致思想是,首先取数据中有序的一个子段(称为run),此后加入栈中,并按照一定策略进行合并。使用栈并且规定合并策略是为了保证合并时两个数组的长度差距可控,否则算法最差将会退化为 O(n^2)(即每一次合并其中一个数组都只有固定个数字,比如一个)。

首先从右向左侧取一段run(保证大小至少为4,可以用插入排序创造),之后将这个run存储在run栈之内。由于栈先进先出的性质,因此栈顶的栈是数组中最靠左的run。而数组中每一个run左侧的run,就是run栈中更靠栈顶的run。入栈之后,就可以开始run的合并了。

合并算法关键在于,每次合并操作都需要保持性质:每个升序段至少比左侧的(即下一个升序段)大2倍。若维持这个性质,那么最短的情况下,栈中run长度应该是按2的幂次递增的(比如栈顶run长度为2^2,则下一段长度为2^3,以此类推)。因此在一般情况下,对于新入栈的不太长的run,我们可以直接合并栈顶两个升序段。

比如下图,栈中原本有两端run(即mid … tail和tail … 末尾)是保持性质的(一段为4,一段为8)。

而对于新入栈的run(即head … mid),它的长度较短,因此栈顶两端run(即head … mid和mid … tail)的长度差距不大,直接合并即可。

如果新的run实在是太长,超过了他右侧run的两倍大小,则优先合并右侧的两端,直到和新的run长度差距在两倍之内。如下图,当前run(head … mid)显然远超过下一段(mid … tail),因此先合并后两段(mid … tail和tail … tail2)。注意,这里为了便于演示取消了run最小长度为4的限制。

合并之后发现,当前run的长度依旧大于下一段长度的两倍,因此继续合并后两段(mid … tail和tail … tail2)。

此时,临近两端长度接近,因此合并这两段即可。

由此,在理想情况下(即数组足够长),该算法保证合并时两个有序数组的长度差小于两倍,由此保证了算法不会退化。此外,由于栈中最小情况run长度都是以2的幂次增长,并且合并的复杂度为O(m+n),因此算法整体复杂度还是O(n \log n)

算法关键部分的Java实现如下。注意这里并不是直接判断比右侧大两倍的,而是以跨一个run进行判断(可以节省一个×2)。

/*
 * 算法从右到左查找升序段,并按照策略合并升序段
 * 升序段维持性质:每个升序段至少比左侧的(即下一个升序段)大2倍
 * 通过维持这个性质,尽可能保证两个升序段在合并时长度接近(不小于一半)
 * 在合并时,若当前段过大,则优先合并之后两段
 */
int head = right;
// 存放升序段(称run)的栈
var runStack = new int[32]; // 最多 2^32 - 1 元素
int runSize = 0;
do {
    int mid = head;
    head--;
    // 寻找升序段 head ... mid
    while (head > left) {
        if (cmp.compare(arr[head - 1], arr[head]) > 0) {
            if (mid - head < 4) // 升序段太短
                insertFirst(arr, head - 1, mid, cmp); // 从左到右进行插入
            else
                break;
        }
        head--;
    }
    // 此时新升序段未入栈。如果还有其他段,从左到右(弹栈)检查以维持性质
    // 最大同时弹出3段,指针分布:head ... mid ... tail ... nextTail
    while (runSize >= 1) {
        // 取最近一段:mid ... tail
        int tail = runStack[runSize - 1];
        // 如果有三段以上
        while (runSize >= 2) {
            // 取之前一段:tail ... nextTail
            int nextTail = runStack[runSize - 2];
            // 如果当前升序段比之前段要短,则可以直接处理前2段
            if ((mid - head) <= (nextTail - tail))
                break;
            // 如果当前升序段比之前段要长,则需要防止出现合并的两端长度差距过大
            // 因此此时合并 mid ... tail ... nextTail
            // 根据性质合并之后的新段小于更之前的段,继续循环处理
            merge(arr, mid, tail, nextTail, buf, cmp);
            // 弹栈并向后处理
            tail = nextTail;
            runSize--;
        }
        // 检查是否保持性质,或在 head <= left 时循环合并所有分段
        if (head > left && (mid - head) <= (tail - mid) / 2)
            break;
        // 如果不符合性质,合并相邻两个升序段
        merge(arr, head, mid, tail, buf, cmp);
        // 遍历前一分段
        mid = tail;
        runSize--;
    }
    // 增加查找到的升序段
    runStack[runSize] = mid;
    runSize++;
}
while (head > left);

完整的代码可以参考:net.kaaass.kflight.algorithm.sort.AdaptiveMergeSort::sort

稳定化快排、三者取中、双枢轴优化

快排本身并不是稳定的,不过好在稳定快排并不难,引入额外空间就可以实现了。另外,由于此场景下是针对对象排序,因此还可以使用null来进行标记,以减少比较次数。简单实现可以参考:net.kaaass.kflight.algorithm.sort.StableQuickSort::partition

三者取中很好理解,就是在区间开头、中间、结尾取三个数,并使用中间的一个作为partition的依据。

双枢轴优化并不是指的JDK当前采用的双枢轴快速排序,而是只是用一个排序依据,大致相当于双枢轴快速排序的一个退化情况。由于程序场景下会遇到一定量的重复数据,因此三者取中优化并增添一个“== 排序依据”的分段效果要更好。大致分段情况如下:

left         lfPivot         rtPivot        right
| ... < key ... | ... = key ... | ... > key ... |

partition大体逻辑的Java实现如下。注意这里采用了真值表优化三者取中的运算。

// 处理分段
lfPivot = left;
rtPivot = right;
if (right - left >= 2) {
/*
 * 选择 Key 算法的推导如下,假设三个数 a, b, c,首先计算:
 *   c1 = a < b
 *   c2 = b < c
 *   c3 = a < c
 * 之后,可以列出下标:
 *   ret | c1 | c2 | c3
 *    b  | 0  | 0  | 0
 *    _  | 0  | 0  | 1
 *    c  | 0  | 1  | 0
 *    a  | 0  | 1  | 1
 *    a  | 1  | 0  | 0
 *    c  | 1  | 0  | 1
 *    _  | 1  | 1  | 0
 *    b  | 1  | 1  | 1
 * 观察表格,可以发现表格上下对称,因此可以用异或进行判断。
 * 由于 compare 的代价较高,并且一般情况下数据有一定顺序,
 * 因此优先选择位置靠中的 b,若非 b 再计算 c3。
 */
int mid = (left >> 1) + (right >> 1);
S aLf = arr[left], key = arr[mid], aRt = arr[right - 1];
boolean cmp1 = cmpr.compare(aLf, key) < 0,
        cmp2 = cmpr.compare(key, aRt) < 0;
if (cmp1 ^ cmp2) {
    boolean cmp3 = cmpr.compare(aLf, aRt) < 0;
    if (cmp2 ^ cmp3)
        key = aRt;
    else
        key = aLf;
}
/*
 * 排序采用了双枢轴优化。不过此处的优化与 DualPivot 不同,
 * 两个枢轴之间仅仅存放与 key 相同的数据。由于为了维持排序
 * 稳定使用了额外空间 buf,因此可以利用该空间减少赋值操作。
 * 为了减少比较次数,采用了在 buf 不同位置存放,并且使用
 * null 做标记。
 *
 * left         lfPivot         rtPivot        right
 * | ... < key ... | ... = key ... | ... > key ... |
 */
int lfCur = left, midCur, rtCur, cmp;
midCur = rtCur = right - 1;
for (var i = left; i < right; i++) {
    cmp = cmpr.compare(arr[i], key);
    if (cmp < 0) {
        buf[lfCur] = arr[i];
        lfCur++;
        arr[i] = null; // 作标记
    } else if (cmp == 0) {
        buf[midCur] = arr[i];
        midCur--;
        arr[i] = null; // 作标记
    }  // else: pass
}
for (var i = right - 1; i >= left; i--) {
    if (arr[i] != null) {
        if (rtCur != i)
            arr[rtCur] = arr[i];
        rtCur--;
    }
}
// 复制
if (lfCur - left >= 0)
    System.arraycopy(buf, left, arr, left, lfCur - left);
for (var i = 0; i < right - 1 - midCur; i++)
    arr[lfCur + i] = (S) buf[right - 1 - i];
// 分段
lfPivot = lfCur;
rtPivot = rtCur + 1;

最后为了减少递归次数,可以采用尾递归优化。实际上我横向对比了多种方式,包括使用栈去递归,但由于Java的语言特性都不及尾递归优化效果好。

完整的逻辑可以参考:net.kaaass.kflight.algorithm.sort.StableTriQuickSort::sort

混合排序算法

大致组合方式与整体思路中介绍的无异。值得一提的就是对递归过深,采用了2 \log n作为判定界限。基本上与std::sort实现的IntroSort保持一致。由于每一个组成的算法都是稳定的,因此最终的排序算法也是稳定的。

完整的逻辑可以参考:net.kaaass.kflight.algorithm.sort.StableHybridSort::sort

Benchmark

Benchmark的结果还是比较出人意料的,在大多项目上优化后的排序效果(即StableHybridSort)竟然超过了Arrays.sort(混合TimSort)。

====== Random ======
StableQuickSort avg: 62.574269 ms
StableTriQuickSort avg: 60.347577 ms
AdaptiveMergeSort avg: 53.475586 ms
StableHybridSort avg: 58.909563 ms
StableHybridSort.normal avg: 60.792440 ms
Java Arrays.sort avg: 62.191020 ms

====== NearlySorted ======
StableQuickSort avg: 354.940236 ms
StableTriQuickSort avg: 2.774454 ms
AdaptiveMergeSort avg: 2.387234 ms
StableHybridSort avg: 2.130851 ms
StableHybridSort.normal avg: 6.248042 ms
Java Arrays.sort avg: 0.841049 ms
BiInsertSort avg: 0.769681 ms

====== MostSame ======
StableQuickSort avg: 349.206805 ms
StableTriQuickSort avg: 8.321597 ms
AdaptiveMergeSort avg: 18.107755 ms
StableHybridSort avg: 9.566199 ms
StableHybridSort.normal avg: 28.683752 ms
Java Arrays.sort avg: 10.712380 ms

====== Reverse ======
StableQuickSort: Done in 3040.477061 ms.
StableTriQuickSort: Done in 4.314071 ms.
AdaptiveMergeSort: Done in 3.653426 ms.
StableHybridSort: Done in 5.883134 ms.
StableHybridSort.normal: Done in 10.668082 ms.
Java Arrays.sort: Done in 0.430858 ms.

Reference

  1. 可视化:https://github.com/hediet/vscode-debug-visualizer
分享到

KAAAsS

喜欢二次元的程序员,喜欢发发教程,或者偶尔开坑。(←然而并不打算填)

相关日志

  1. 没有图片
  2. 没有图片
  3. 没有图片
  4. 没有图片

评论

还没有评论。

在此评论中不能使用 HTML 标签。