Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

java8 TimSort源码解析 #1

Open
jacky1193610322 opened this issue Apr 18, 2018 · 0 comments
Open

java8 TimSort源码解析 #1

jacky1193610322 opened this issue Apr 18, 2018 · 0 comments

Comments

@jacky1193610322
Copy link
Owner

jacky1193610322 commented Apr 18, 2018

java8种Arrays.sort使用的是TimSort方法,其实从总体来看是我个人看法是归并排序的优化版,先不说作用,先从代码说起

static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c,
                         T[] work, int workBase, int workLen) {
        assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;

        int nRemaining  = hi - lo;
        /* 表示要排序的数量
         * 因为Hi表示的下标是不参与排序的,所以这里计算直接是hi - ho
         * 如果要排序的元素数量小于2,那就不需要排序,已经是有序了
         */
        if (nRemaining < 2)
            return;  // Arrays of size 0 and 1 are always sorted

        // If array is small, do a "mini-TimSort" with no merges
        /* 这里如果要排序的元素数量是小于MIN_MERGE 这里MIN_MERGE是32 但是这里
         * 的值应该只是一个经验值, 我现在还不知道作者为什么设置为这个值, 下文中暂且不考虑这个值的来源, 如果小于这个阈值,就会使用二分查找插入
         */
        if (nRemaining < MIN_MERGE) {
            int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
            binarySort(a, lo, hi, lo + initRunLen, c);
            return;
        }

        /**
         * March over the array once, left to right, finding natural runs,
         * extending short natural runs to minRun elements, and merging runs
         * to maintain stack invariant.
         */
        TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
        // 这个是计算最小run的大小 暂且不关注
        int minRun = minRunLength(nRemaining);
        do {
            // Identify next run
            // 计算run的长度, 函数的解析在下面, 先看这个函数的用处,就是计算
            // 数组中连续递增的区间长度,如果是递减,那么一样算,只不过会反转过来
            int runLen = countRunAndMakeAscending(a, lo, hi, c);

            // If run is short, extend to min(minRun, nRemaining)
            // 如果这个区间比较小,那么会将他后面的元素按照二分插入的方式,插入进去,凑成一个满足一定长度的递增区间(不是严格递增 可以中间相等,但是必须是递增趋势)
            if (runLen < minRun) {
                int force = nRemaining <= minRun ? nRemaining : minRun;
                binarySort(a, lo, lo + force, lo + runLen, c);
                runLen = force;
            }

            // Push run onto pending-run stack, and maybe merge
            // 这个是个栈,将这个递增区间的起始下标和长度压入栈中
            ts.pushRun(lo, runLen);
            // 查看要不要合并相邻的run  这个合并可以理解为准备合并两个
            // 有序的区间为一个大的区间
            ts.mergeCollapse();

            // Advance to find next run
            lo += runLen;
            nRemaining -= runLen;
        } while (nRemaining != 0);

        // Merge all remaining runs to complete sort
        assert lo == hi;
        ts.mergeForceCollapse();
        assert ts.stackSize == 1;
    }

二分查找参见 二分查找
python 版Timsort的各个设计的来源

/*
 * 这个函数的目的是为了查找连续递增(这个递增包括相等,但是总的趋势是增)
 * 或者是严格递减(也就是不能包含相等)这样做的原因是为了排序的稳定性
 * 如果是严格递减,那么就会反转过来
 */
private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi,
                                                    Comparator<? super T> c) {
        assert lo < hi;
        int runHi = lo + 1;
        if (runHi == hi)
            return 1;

        // Find end of run, and reverse range if descending
        // 先比较前两个的元素的大小,来确定接下来的趋势
        if (c.compare(a[runHi++], a[lo]) < 0) { // Descending
            while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
                runHi++;
            reverseRange(a, lo, runHi);
        } else {                              // Ascending
            while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
                runHi++;
        }

        return runHi - lo;
    }
/* 这个就是二分插入了,也就是已经有了一段有序的区间,还有剩下的部分不是有序的,那么
 * 循环将剩下的部分利用二分查找找到应该插入的位置
 */
private static <T> void binarySort(T[] a, int lo, int hi, int start,
                                       Comparator<? super T> c) {
        assert lo <= start && start <= hi;
        // 如果start等于最小的下标,那么第一个元素可以认为是有序的
        if (start == lo)
            start++;
        /* 开始循环 将[start, hi) 逐个插入排序, 
         * start代表不在有序区间的第一个元素下标
         * 也就是说[start, hi) 这个区间不是有序的
         */
        for ( ; start < hi; start++) {
            T pivot = a[start];

            // Set left (and right) to the index where a[start] (pivot) belongs
            int left = lo;
            int right = start;
            assert left <= right;
            /*
             * Invariants:
             *   pivot >= all in [lo, left).
             *   pivot <  all in [right, start).
             *   这个部分其实就是找到第一个大于privot的下标
             *   这个部分其实是有难度的,二分查找有很多变种
             *   比如 数据从小到大,查找第一个大于某个数key的下标位置
             *   查找第一个大于某个数key的下标位置
             *   下一篇会写
             */
            while (left < right) {
                int mid = (left + right) >>> 1;
                if (c.compare(pivot, a[mid]) < 0)
                    right = mid;
                else
                    left = mid + 1;
            }
            // 循环结束之后left == right 那么也就是要找的下标位置
            assert left == right;

            /*
             * The invariants still hold: pivot >= all in [lo, left) and
             * pivot < all in [left, start), so pivot belongs at left.  Note
             * that if there are elements equal to pivot, left points to the
             * first slot after them -- that's why this sort is stable.
             * Slide elements over to make room for pivot.
             */
            int n = start - left;  // The number of elements to move
            // Switch is just an optimization for arraycopy in default case
            switch (n) {
                case 2:  a[left + 2] = a[left + 1];
                case 1:  a[left + 1] = a[left];
                         break;
                // 这边就相当于left到start前一位的元素先后移动一位,空出来一个位置给privot, 就这样一直到全部插入完, 上面case 2 和case 1z只是简单版copy
                default: System.arraycopy(a, left, a, left + 1, n);
            }
            a[left] = pivot;
        }
    }
/* 这个是为了将两个run合并为一个run 但是要符合条件才会合并,
 *   1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
 *   2. runLen[i - 2] > runLen[i - 1]
 *   如果不满足上面两个条件就会合并
 *   上面两个条件比较抽象 
 *   可以想象为汉诺塔,底下大,上面慢慢变小, 但是有点不同的是,
 *   上面两个的run的长度之和小于底下的那个
 *   为什么这样限制呢,因为
 *   1. run的合并只能合并相邻的run 不能跨run合并,你可以尝试一下跨run之后代码会
 *   变得很复杂
 *   2. 尽量让run长度相差不大的情况下合并 为了减少比较的次数  假设run的长度依次为
 *   1000 1 1
 *   为了快速体现出来上面是极端例子,也就是说一个只有一个人的队伍插入到一个1000人排好序的队伍,它需要从头开始挨个比较,这个插入按照归并排序的merge方法插入
 *   上面最坏是1000 次比较,假设有1000个1 那么就是粗略的1000 * 1000
 *   但假设我们1000个1 合并成一个1000人拍好序的队伍,与另一个拍好序的队伍合并
 *   那么相当于最坏2000次比较就行了
 */
private void mergeCollapse() {
    while (stackSize > 1) {
        int n = stackSize - 2;
        if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
            if (runLen[n - 1] < runLen[n + 1])
                n--;
            mergeAt(n);
        } else if (runLen[n] <= runLen[n + 1]) {
            mergeAt(n);
        } else {
            break; // Invariant is established
        }
    }
}
private void mergeAt(int i) {
        assert stackSize >= 2;
        assert i >= 0;
        assert i == stackSize - 2 || i == stackSize - 3;

        int base1 = runBase[i];
        int len1 = runLen[i];
        int base2 = runBase[i + 1];
        int len2 = runLen[i + 1];
        assert len1 > 0 && len2 > 0;
        assert base1 + len1 == base2;

        /*
         * Record the length of the combined runs; if i is the 3rd-last
         * run now, also slide over the last run (which isn't involved
         * in this merge).  The current run (i+1) goes away in any case.
         */
        runLen[i] = len1 + len2;
        if (i == stackSize - 3) {
            runBase[i + 1] = runBase[i + 2];
            runLen[i + 1] = runLen[i + 2];
        }
        stackSize--;

        /*
         * Find where the first element of run2 goes in run1. Prior elements
         * in run1 can be ignored (because they're already in place).
         */
         /* 这个函数是为了找到a[base2] 在 a[base1, base1 + len1-1]中位置 
          * 第一个大于a[base2]的下标
          * 并且返回这个下标距离起始点的长度
          * 这个长度代表着,第二个run的第一个元素比他们都要大,
          * 那么这个区间就不需要排序了
          */
        int k = gallopRight(a[base2], a, base1, len1, 0, c);
        assert k >= 0;
        base1 += k;
        len1 -= k;
        if (len1 == 0)
            return;

        /*
         * Find where the last element of run1 goes in run2. Subsequent elements
         * in run2 can be ignored (because they're already in place).
         * 找到a[base1 + len1 - 1]在a[base2, base2 + len - 1 ]中的
         * 第一个大于等于它的下标并且返回这个下标距离终点的长度
         * 这个长度代表着,第1个run的最后一个元素比他们都要小,
         * 那么这个区间就不需要排序了 因为肯定是最大的
         * 为什么这个是大于等于 而gallopRight是大于 因为稳定排序的要求
         * 这个大于等于是因为,开始的这个元素是第一个run的,那么这个元素加入在第二
         * 个run有相等的元素,那么这个元素应该排在第一个run的前面才对,而
         * gallopRight 应该排在后面才能满足稳定排序的要求
         */
        len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c);
        assert len2 >= 0;
        if (len2 == 0)
            return;

        // Merge remaining runs, using tmp array with min(len1, len2) elements
        if (len1 <= len2)
            mergeLo(base1, len1, base2, len2);
        else
            mergeHi(base1, len1, base2, len2);
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant