小雨兮兮 發表於 2022-3-6 16:07:00

go 中 sort 如何排序,源码解读

<ul>
<li>sort 包源码解读
<ul>
<li>前言</li>
<li>如何使用
<ul>
<li>基本数据类型切片的排序</li>
<li>自定义 Less 排序比较器</li>
<li>自定义数据结构的排序</li>
</ul>
</li>
<li>分析下源码
<ul>
<li>不稳定排序</li>
<li>稳定排序</li>
<li>查找</li>
<li>Interface</li>
</ul>
</li>
<li>总结</li>
<li>参考</li>
</ul>
</li>
</ul>

<h2 id="sort-包源码解读">sort 包源码解读</h2>
<h3 id="前言">前言</h3>
<p>我们的代码业务中很多地方需要我们自己进行排序操作,go 标准库中是提供了 sort 包是实现排序功能的,这里来看下生产级别的排序功能是如何实现的。</p>
<p><code>go version go1.16.13 darwin/amd64</code></p>
<h3 id="如何使用">如何使用</h3>
<p>先来看下 sort 提供的主要功能</p>
<ul>
<li>
<p>对基本数据类型切片的排序支持</p>
</li>
<li>
<p>自定义 Less 排序比较器</p>
</li>
<li>
<p>自定义数据结构的排序</p>
</li>
<li>
<p>判断基本数据类型切片是否已经排好序</p>
</li>
<li>
<p>基本数据元素查找</p>
</li>
</ul>
<h4 id="基本数据类型切片的排序">基本数据类型切片的排序</h4>
<p>sort 包中已经实现了对 <code>[]int, []float, []string</code> 这几种类型的排序</p>
<pre><code class="language-go">func TestSort(t *testing.T) {
        s := []int{5, 2, 6, 3, 1, 4}
        fmt.Println("是否排好序了", sort.IntsAreSorted(s))
        sort.Ints(s)
        // 正序
        fmt.Println(s)
        // 倒序
        sort.Sort(sort.Reverse(sort.IntSlice(s)))
        fmt.Println(s)
        // 稳定排序
        sort.Stable(sort.IntSlice(s))
        fmt.Println("是否排好序了", sort.IntsAreSorted(s))
        fmt.Println("查找是否存在", sort.SearchInts(s, 5))
        fmt.Println(s)

        str := []string{"s", "f", "d", "c", "r", "a"}
        sort.Strings(str)
        fmt.Println(str)

        flo := []float64{1.33, 4.78, 0.11, 6.77, 8.99, 4.22}
        sort.Float64s(flo)
        fmt.Println(flo)
}
</code></pre>
<p>看下输出</p>
<pre><code>是否排好序了 false


是否排好序了 true
查找是否存在 4



</code></pre>
<p>sort 本身不是稳定排序,需要稳定排序使用<code>sort.Stable</code>,同时排序默认是升序,降序可使用<code>sort.Reverse</code></p>
<h4 id="自定义-less-排序比较器">自定义 Less 排序比较器</h4>
<p>如果我们需要进行的排序的内容是一些复杂的结构,例如下面的栗子,是个结构体,根据结构体中的某一个属性进行排序,这时候可以通过自定义 Less 比较器实现</p>
<p>使用 <code>sort.Slice</code>,<code>sort.Slice</code>中提供了 less 函数,我们,可以自定义这个函数,然后通过<code>sort.Slice</code>进行排序,<code>sort.Slice</code>不是稳定排序,稳定排序可使用<code>sort.SliceStable</code></p>
<pre><code class="language-go">type Person struct {
        Name string
        Ageint
}

func TestSortSlice(t *testing.T) {
        people := []Person{
                {"Bob", 31},
                {"John", 42},
                {"Michael", 17},
                {"Jenny", 26},
        }

        sort.Slice(people, func(i, j int) bool {
                return people.Age &lt; people.Age
        })
        // Age正序
        fmt.Println(people)
        // Age倒序
        sort.Slice(people, func(i, j int) bool {
                return people.Age &gt; people.Age
        })
        fmt.Println(people)

        // 稳定排序
        sort.SliceStable(people, func(i, j int) bool {
                return people.Age &gt; people.Age
        })
        fmt.Println(people)
}
</code></pre>
<p>看下输出</p>
<pre><code>[{Michael 17} {Jenny 26} {Bob 31} {John 42}]
[{John 42} {Bob 31} {Jenny 26} {Michael 17}]
[{John 42} {Bob 31} {Jenny 26} {Michael 17}]
</code></pre>
<h4 id="自定义数据结构的排序">自定义数据结构的排序</h4>
<p>对自定义结构的排序,除了可以自定义 Less 排序比较器之外,sort 包中也提供了<code>sort.Interface</code>接口,我们只要实现了<code>sort.Interface</code>中提供的三个方法,即可通过 sort 包内的函数完成排序,查找等操作</p>
<pre><code class="language-go">// An implementation of Interface can be sorted by the routines in this package.
// The methods refer to elements of the underlying collection by integer index.
type Interface interface {
        // Len is the number of elements in the collection.
        Len() int

        // Less reports whether the element with index i
        // must sort before the element with index j.
        //
        // If both Less(i, j) and Less(j, i) are false,
        // then the elements at index i and j are considered equal.
        // Sort may place equal elements in any order in the final result,
        // while Stable preserves the original input order of equal elements.
        //
        // Less must describe a transitive ordering:
        //- if both Less(i, j) and Less(j, k) are true, then Less(i, k) must be true as well.
        //- if both Less(i, j) and Less(j, k) are false, then Less(i, k) must be false as well.
        //
        // Note that floating-point comparison (the &lt; operator on float32 or float64 values)
        // is not a transitive ordering when not-a-number (NaN) values are involved.
        // See Float64Slice.Less for a correct implementation for floating-point values.
        Less(i, j int) bool

        // Swap swaps the elements with indexes i and j.
        Swap(i, j int)
}
</code></pre>
<p>来看下如何使用</p>
<pre><code class="language-go">type ByAge []Person

func (a ByAge) Len() int         { return len(a) }
func (a ByAge) Swap(i, j int)      { a, a = a, a }
func (a ByAge) Less(i, j int) bool { return a.Age &lt; a.Age }

func TestSortStruct(t *testing.T) {
        people := []Person{
                {"Bob", 31},
                {"John", 42},
                {"Michael", 17},
                {"Jenny", 26},
        }

        sort.Sort(ByAge(people))
        fmt.Println(people)
}
</code></pre>
<p>输出</p>
<pre><code>[{Michael 17} {Jenny 26} {Bob 31} {John 42}]
</code></pre>
<p>当然 sort 包中已经实现的<code>[]int, []float, []string</code> 这几种类型的排序也是实现了<code>sort.Interface</code>接口</p>
<p>对于上面的三种排序,第一种和第二种基本上就能满足我们的额需求了,不过第三种灵活性更强。</p>
<h3 id="分析下源码">分析下源码</h3>
<p>先来看下什么是稳定性排序</p>
<p>栗如:对一个数组进行排序,如果里面有重复的数据,排完序时候,相同的数据的相对索引位置没有发生改变,那么就是稳定排序。</p>
<p>也就是里面有两个5,5。排完之后第一个5还在最前面,没有和后面的重复数据5发生过位置的互换,那么这就是稳定排序。</p>
<h4 id="不稳定排序">不稳定排序</h4>
<p>sort 中的排序算法用到了,quickSort(快排),heapSort(堆排序),insertionSort(插入排序),shellSort(希尔排序)</p>
<p>先来分析下这几种排序算法的使用</p>
<p>可以看下调用 Sort 进行排序,最终都会调用 quickSort</p>
<pre><code class="language-go">func Sort(data Interface) {
        n := data.Len()
        quickSort(data, 0, n, maxDepth(n))
}
</code></pre>
<p>再来看下 quickSort 的实现</p>
<pre><code class="language-go">func quickSort(data Interface, a, b, maxDepth int) {
        // 切片长度大于12的时候使用快排
        for b-a &gt; 12 { // Use ShellSort for slices &lt;= 12 elements
                // maxDepth 返回快速排序应该切换的阈值
                // 进行堆排序
                // 当 maxDepth为0的时候进行堆排序
                if maxDepth == 0 {
                        heapSort(data, a, b)
                        return
                }
                maxDepth--
                // doPivot 是快排核心算法,它取一点为轴,把不大于轴的元素放左边,大于轴的元素放右边,返回小于轴部分数据的最后一个下标,以及大于轴部分数据的第一个下标
                // 下标位置 a...mlo,pivot,mhi...b
                // data &lt;= data
                // data &gt; data
                // 和中位数一样的数据就不用在进行交换了,维护这个范围值能减少数据的次数
                mlo, mhi := doPivot(data, a, b)
                // 避免递归过深
                // 循环是比递归节省时间的,如果有大规模的子节点,让小的先递归,达到了 maxDepth 也就是可以触发堆排序的条件了,然后使用堆排序进行排序
                if mlo-a &lt; b-mhi {
                        quickSort(data, a, mlo, maxDepth)
                        a = mhi // i.e., quickSort(data, mhi, b)
                } else {
                        quickSort(data, mhi, b, maxDepth)
                        b = mlo // i.e., quickSort(data, a, mlo)
                }
        }
        // 如果切片的长度大于1小于等于12的时候,使用 shell 排序
        if b-a &gt; 1 {
                // Do ShellSort pass with gap 6
                // It could be written in this simplified form cause b-a &lt;= 12
                // 这里先做一轮shell 排序
                for i := a + 6; i &lt; b; i++ {
                        if data.Less(i, i-6) {
                                data.Swap(i, i-6)
                        }
                }
                // 进行插入排序
                insertionSort(data, a, b)
        }
}

// maxDepth 返回快速排序应该切换的阈值
// 进行堆排序
func maxDepth(n int) int {
        var depth int
        for i := n; i &gt; 0; i &gt;&gt;= 1 {
                depth++
        }
        return depth * 2
}

// doPivot 是快排核心算法,它取一点为轴,把不大于轴的元素放左边,大于轴的元素放右边,返回小于轴部分数据的最后一个下标,以及大于轴部分数据的第一个下标
// 下标位置 lo...midlo,pivot,midhi...hi
// data &lt;= data
// data &gt; data
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
        m := int(uint(lo+hi) &gt;&gt; 1) // Written like this to avoid integer overflow.
        // 这里用到了 Tukey's ninther 算法,文章链接 https://www.johndcook.com/blog/2009/06/23/tukey-median-ninther/
        // 通过该算法求出中位数
        if hi-lo &gt; 40 {
                // Tukey's ``Ninther,'' median of three medians of three.
                s := (hi - lo) / 8
                medianOfThree(data, lo, lo+s, lo+2*s)
                medianOfThree(data, m, m-s, m+s)
                medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
        }

        // 求出中位数 data &lt;= data &lt;= data
        medianOfThree(data, lo, m, hi-1)

        // Invariants are:
        //        data = pivot (set up by ChoosePivot)
        //        data &lt; pivot
        //        data &lt;= pivot
        //        data unexamined
        //        data &gt; pivot
        //        data &gt;= pivot
        // 中位数
        pivot := lo
        a, c := lo+1, hi-1

        // 处理使 data &lt; pivot
        for ; a &lt; c &amp;&amp; data.Less(a, pivot); a++ {
        }
        b := a
        for {
                // 处理使 data &lt;= pivot
                for ; b &lt; c &amp;&amp; !data.Less(pivot, b); b++ {
                }
                // 处理使 data &gt; pivot
                for ; b &lt; c &amp;&amp; data.Less(pivot, c-1); c-- { // data &gt; pivot
                }
                // 左边和右边重合或者已经在右边的右侧
                if b &gt;= c {
                        break
                }
                // data &gt; pivot; data &lt;= pivot
                // 左侧的数据大于右侧,交换,然后接着排序
                data.Swap(b, c-1)
                b++
                c--
        }
        // If hi-c&lt;3 then there are duplicates (by property of median of nine).
        // Let's be a bit more conservative, and set border to 5.
        // 如果 hi-c&lt;3 则存在重复项(按中位数为 9 的属性)。
        // 让我们稍微保守一点,将边框设置为 5。

        // 因为c为划分pivot的大小的临界值,所以在9值划分时,正常来说,应该是两边各4个
        // 由于左边是&lt;=,多了个相等的情况,所以5,3分布,也是没有问题
        // 如果hi-c&lt;3,c的值明显偏向于hi,说明有多个和pivot重复值
        // 为了更保守一点,所以设置为5(反正只是多校验一次而已)
        protect := hi-c &lt; 5
        // 即便大于等于5,也可能是因为元素总值很多,所以对比hi-c是否小于总数量的1/4
        if !protect &amp;&amp; hi-c &lt; (hi-lo)/4 {
                // 用一些特殊的点和中间数进行比较
                dups := 0
                // 处理使 data = pivot
                if !data.Less(pivot, hi-1) {
                        data.Swap(c, hi-1)
                        c++
                        dups++
                }
                // 处理使 data = pivot
                if !data.Less(b-1, pivot) {
                        b--
                        dups++
                }
                // m-lo = (hi-lo)/2 &gt; 6
                // b-lo &gt; (hi-lo)*3/4-1 &gt; 8
                // ==&gt; m &lt; b ==&gt; data &lt;= pivot
                if !data.Less(m, pivot) { // data = pivot
                        data.Swap(m, b-1)
                        b--
                        dups++
                }
                // 如果上面的 if 进入了两次, 就证明现在是偏态分布(也就是左右不平衡的)
                protect = dups &gt; 1
        }
        // 不平衡,接着进行处理
        // 这里划分的是&lt;pivot和=pivot的两组
        if protect {
                // Protect against a lot of duplicates
                // Add invariant:
                //        data unexamined
                //        data = pivot
                for {
                        // 处理使 data == pivot
                        for ; a &lt; b &amp;&amp; !data.Less(b-1, pivot); b-- {
                        }
                        // 处理使 data &lt; pivot
                        for ; a &lt; b &amp;&amp; data.Less(a, pivot); a++ {
                        }
                        if a &gt;= b {
                                break
                        }
                        // data == pivot; data &lt; pivot
                        data.Swap(a, b-1)
                        a++
                        b--
                }
        }
        // 交换中位数到中间
        data.Swap(pivot, b-1)
        return b - 1, c
}
</code></pre>
<p>对于这几种排序算法的使用,sort 包中是混合使用的</p>
<p>1、如果切片长度大于12的时候使用快排,使用快排的时候,如果满足了使用堆排序的条件没这个排序对于后面的数据的处理,又会转换成堆排序;</p>
<p>2、切片长度小于12了,就使用 shell 排序,shell 排序只处理一轮数据,后面数据的排序使用插入排序;</p>
<p>堆排序和插入排序就是正常的排序处理了</p>
<pre><code class="language-go">// insertionSort sorts data using insertion sort.
// 插入排序
func insertionSort(data Interface, a, b int) {
        for i := a + 1; i &lt; b; i++ {
                for j := i; j &gt; a &amp;&amp; data.Less(j, j-1); j-- {
                        data.Swap(j, j-1)
                }
        }
}

// 堆排序
func heapSort(data Interface, a, b int) {
        first := a
        lo := 0
        hi := b - a

        // Build heap with greatest element at top.
        for i := (hi - 1) / 2; i &gt;= 0; i-- {
                siftDown(data, i, hi, first)
        }

        // Pop elements, largest first, into end of data.
        for i := hi - 1; i &gt;= 0; i-- {
                data.Swap(first, first+i)
                siftDown(data, lo, i, first)
        }
}
</code></pre>
<h4 id="稳定排序">稳定排序</h4>
<p>sort 包中也提供了稳定的排序,通过调用<code>sort.Stable</code>来实现</p>
<pre><code class="language-go">// It makes one call to data.Len to determine n, O(n*log(n)) calls to
// data.Less and O(n*log(n)*log(n)) calls to data.Swap.
func Stable(data Interface) {
        stable(data, data.Len())
}

func stable(data Interface, n int) {
        // 定义切片块的大小
        blockSize := 20 // must be &gt; 0
        a, b := 0, blockSize
        // 如果切片长度大于块的大小,分多次对每个块中进行排序   
        for b &lt;= n {
                insertionSort(data, a, b)
                a = b
                b += blockSize
        }
        insertionSort(data, a, n)

        // 如果有多个块,对排好序的块进行合并操作
        for blockSize &lt; n {
                a, b = 0, 2*blockSize
                for b &lt;= n {
                        symMerge(data, a, a+blockSize, b)
                        a = b
                        b += 2 * blockSize
                }
                if m := a + blockSize; m &lt; n {
                        symMerge(data, a, m, n)
                }
                // block 每次循环扩大两倍, 直到比元素的总个数大,就结束
                blockSize *= 2
        }
}

func symMerge(data Interface, a, m, b int) {
        // 如果只有一个元素避免没必要的递归,这里直接插入
        // 处理左边部分
        if m-a == 1 {
                // 使用二分查找查找最低索引 i
                // 这样 data &gt;= data for m &lt;= i &lt; b.
                // 如果不存在这样的索引,则使用 i == b 退出搜索循环。
                i := m
                j := b
                for i &lt; j {
                        h := int(uint(i+j) &gt;&gt; 1)
                        if data.Less(h, a) {
                                i = h + 1
                        } else {
                                j = h
                        }
                }
                // Swap values until data reaches the position before i.
                for k := a; k &lt; i-1; k++ {
                        data.Swap(k, k+1)
                }
                return
        }

        // 同上
        // 处理右边部分
        if b-m == 1 {
                // Use binary search to find the lowest index i
                // such that data &gt; data for a &lt;= i &lt; m.
                // Exit the search loop with i == m in case no such index exists.
                i := a
                j := m
                for i &lt; j {
                        h := int(uint(i+j) &gt;&gt; 1)
                        if !data.Less(m, h) {
                                i = h + 1
                        } else {
                                j = h
                        }
                }
                // Swap values until data reaches the position i.
                for k := m; k &gt; i; k-- {
                        data.Swap(k, k-1)
                }
                return
        }

        for start &lt; r {
                c := int(uint(start+r) &gt;&gt; 1)
                if !data.Less(p-c, c) {
                        start = c + 1
                } else {
                        r = c
                }
        }

        end := n - start
        if start &lt; m &amp;&amp; m &lt; end {
                rotate(data, start, m, end)
        }
        // 递归的进行归并操作
        if a &lt; start &amp;&amp; start &lt; mid {
                symMerge(data, a, start, mid)
        }
        if mid &lt; end &amp;&amp; end &lt; b {
                symMerge(data, mid, end, b)
        }
}
</code></pre>
<p>对于稳定排序,用到了插入排序和归并排序</p>
<p>1、首先会将数据按照每20个一组进行分块,对每个块中的数据使用插入排序完成排序;</p>
<p>2、然后下面使用归并排序,对排序的数据块进行两两归并排序,完成一次排序,扩大数据块为之前的2倍,直到完成所有的排序。</p>
<h4 id="查找">查找</h4>
<p>sort 中的 查找功能最终是调用 search 函数来实现的</p>
<pre><code class="language-go">func SearchInts(a []int, x int) int {
        return Search(len(a), func(i int) bool { return a &gt;= x })
}

// 使用二分查找
func Search(n int, f func(int) bool) int {
        // Define f(-1) == false and f(n) == true.
        // Invariant: f(i-1) == false, f(j) == true.
        i, j := 0, n
        for i &lt; j {
                // 二分查找
                h := int(uint(i+j) &gt;&gt; 1) // avoid overflow when computing h
                // i ≤ h &lt; j
                if !f(h) {
                        i = h + 1 // preserves f(i-1) == false
                } else {
                        j = h // preserves f(j) == true
                }
        }
        // i == j, f(i-1) == false, and f(j) (= f(i)) == true=&gt;answer is i.
        return i
}
</code></pre>
<p>sort 中查找相对比较简单,使用的是二分查找</p>
<h4 id="interface">Interface</h4>
<p>sort 包提供了 Interface 的接口,我们可以自定义数据结构,然后实现 Interface 对应的接口,就能使用 sort 包中的方法</p>
<pre><code class="language-go">type Interface interface {
        Len() int

        Less(i, j int) bool

        Swap(i, j int)
}
</code></pre>
<p>看源码可以看到 sort 包中已有的对 []int 等数据结构的排序,也是实现了 Interface</p>
<pre><code class="language-go">// Convenience types for common cases

// IntSlice attaches the methods of Interface to []int, sorting in increasing order.
type IntSlice []int

func (x IntSlice) Len() int         { return len(x) }
func (x IntSlice) Less(i, j int) bool { return x &lt; x }
func (x IntSlice) Swap(i, j int)      { x, x = x, x }
</code></pre>
<p>这种思路挺好的,之后可以借鉴下,对于可变部分提供抽象接口,让用户根据自己的场景有实现。</p>
<p>对于基础的排序,查找只要实现了 Interface 的方法,就能拥有这些基础的能力了。</p>
<h3 id="总结">总结</h3>
<p>sort 对于排序算法的实现,是结合了多种算法,最终实现了一个高性能的排序算法</p>
<p>抽象出了 IntSlice 接口,用户可以自己去实现对应的方法,然后就能拥有 sort 中提供的能力了</p>
<h3 id="参考">参考</h3>
<p>【文中示例代码】https://github.com/boilingfrog/Go-POINT/blob/master/golang/sort/sort_test.go<br>
【Golang sort 排序】https://blog.csdn.net/K346K346/article/details/118314382<br>
【John Tukey’s median of medians】https://www.johndcook.com/blog/2009/06/23/tukey-median-ninther/<br>
【code_reading】https://github.com/Junedayday/code_reading/blob/master/sort/sort.go<br>
【go中的sort包】https://boilingfrog.github.io/2022/03/06/go中的sort包/</p><br><br>
来源:https://www.cnblogs.com/ricklz/p/15972396.html
頁: [1]
查看完整版本: go 中 sort 如何排序,源码解读