【算法专题】分治 - 快速排序
作者:mmseoamin日期:2024-02-28

分治 - 快速排序

  • 分治 - 快速排序
    • 1. 颜色分类
    • 2. 排序数组(快速排序)
    • 3. 数组中的第K个最大元素
    • 4. 库存管理Ⅲ
    • 5. 排序数组(归并排序)
    • 6. 交易逆序对的总数
    • 7. 计算右侧小于当前元素的个数
    • 8. 翻转对

      分治 - 快速排序

      1. 颜色分类

      做题链接 -> Leetcode -75.颜色分类

      题目:给定一个包含红色、白色和蓝色、共 n 个元素数组 nums ,原地对它们进行排序,使得相同颜色的元素相邻,并按照红色、白色、蓝色顺序排列。

      我们使用整数 0、 1 和 2 分别表示红色、白色和蓝色。

      必须在不使用库内置的 sort 函数的情况下解决这个问题。

      示例 1:

      输入:nums = [2, 0, 2, 1, 1, 0]

      输出:[0, 0, 1, 1, 2, 2]

      示例 2:

      输入:nums = [2, 0, 1]

      输出:[0, 1, 2]

      提示:

      n == nums.length

      1 <= n <= 300

      nums[i] 为 0、1 或 2

      思路:快排思想,三指针法使数组分三块。类比数组分两块的算法思想,这里是将数组分成三块,那么我们可以再添加⼀个指针,实现数组分三块。

      设数组大小为 n ,定义三个指针 left, cur, right :

      • left :用来标记 0(红色) 序列的末尾,因此初始化为 -1 ;
      • cur :用来扫描数组,初始化为 0 ;
      • right :用来标记 2(蓝色) 序列的起始位置,因此初始化为 n 。

        在 cur 往后扫描的过程中,保证:

        • [0, left] 内的元素都是 0(红色) ;
        • [left + 1, cur - 1] 内的元素都是 1(白色) ;
        • [cur, right - 1] 内的元素是待定元素;
        • [right, n] 内的元素都是 2(蓝色) .

          代码如下:

          		class Solution {
          		public:
          		    void sortColors(vector& nums) 
          		    {
          		        // 使用三指针将数组分为三块,最终分为以下三个模块:
          		        // [0, left] 表示 0(红色) 序列;
          		        // [left + 1, right - 1] 表示 1(白色) 序列;
          		        // [right, numsSize - 1] 表示 2(蓝色) 序列。     
          		        int cur = 0, left = -1, right = nums.size();
          		        while(cur < right)
          		        {
          		            if(nums[cur] == 0) swap(nums[++left], nums[cur++]);
          		            else if(nums[cur] == 1) cur++;
          		            else swap(nums[--right], nums[cur]);
          		        }
          		    }
          		};
          

          2. 排序数组(快速排序)

          做题链接 -> Leetcode -912.排序数组

          题目:给你一个整数数组 nums,请你将该数组升序排列。

          示例 1:

          输入:nums = [5, 2, 3, 1]

          输出:[1, 2, 3, 5]

          示例 2:

          输入:nums = [5, 1, 1, 2, 0, 0]

          输出:[0, 0, 1, 1, 2, 5]

          提示:

          1 <= nums.length <= 5 * 10^4

          5 * 10^4 <= nums[i] <= 5 * 10^4

          由于思路比较明显,使用快速选择算法,递归处理选取一个基准值 key 将数组分为三块,下面直接看代码:

          		class Solution {
          		public:
          		    vector sortArray(vector& nums) 
          		    {
          		        // 种下一个随机数种子
          		        srand(time(nullptr));
          		        
          		        // 快速选择算法,将数组划分为三个区间
          		        my_qsort(nums, 0, nums.size() - 1);
          		        return nums;
          		    }
          		
          		    void my_qsort(vector& nums, int l, int r)
          		    {
          		        if(l >= r) return;
          		
          		        // 将数组分三块
          		        int key = getRandom(nums, l, r);
          		        int i = l, left = l - 1, right = r + 1;
          		        while(i < right)
          		        {
          		            if(nums[i] > key) swap(nums[i], nums[--right]);
          		            else if(nums[i] == key) ++i;
          		            else swap(nums[++left], nums[i++]);
          		        }
          		
          		        // [l, left] [left + 1, right - 1] [right, r]
          		        my_qsort(nums, l, left);
          		        my_qsort(nums, right, r);
          		    }
          		
          		    // 获取数组中随机一个数
          		    // 让随机数 % 上区间大小,然后加上区间的左边界
          		    int getRandom(vector& nums, int left, int right)
          		    {
          		        return nums[rand() % (right - left + 1) + left];
          		    }
          		};
          

          3. 数组中的第K个最大元素

          题目链接 -> Leetcode -215.数组中的第K个最大元素

          Leetcode -215.数组中的第K个最大元素

          题目:给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

          请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

          你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。

          示例 1:

          输入: [3, 2, 1, 5, 6, 4] , k = 2

          输出 : 5

          示例 2 :

          输入 : [3, 2, 3, 1, 2, 4, 5, 5, 6] , k = 4

          输出 : 4

          提示:

          • 1 <= k <= nums.length <= 10^5
          • 10^4 <= nums[i] <= 10^4

            思路是使用快排思想,将数组分为三块,然后分三种情况讨论,具体思路参考代码解析;

            代码如下:

            		class Solution {
            		public:
            		    int findKthLargest(vector& nums, int k)
            		    {
            		        srand(time(nullptr));
            		        return FindMaxTopk(nums, 0, nums.size() - 1, k);
            		    }
            		
            		    int FindMaxTopk(vector& nums, int l, int r, int k)
            		    {
            		        if (l == r) return nums[l];
            		
            		        // 根据 key 将数组分为三块
            		        int key = getRandom(nums, l, r);
            		        int i = l, left = l - 1, right = r + 1;
            		        while (i < right)
            		        {
            		            if (nums[i] < key) swap(nums[++left], nums[i++]);
            		            else if (nums[i] == key) i++;
            		            else swap(nums[--right], nums[i]);
            		        }
            		
            		        // [l, left] [left + 1, right - 1] [right, r]
            		        int part2 = right - left - 1, part3 = r - right + 1;
            		
            		        // 分情况讨论
            		        // 情况1、区间3的个数大于等于k,那么目标值一定在区间3
            		        if (part3 >= k) return FindMaxTopk(nums, right, r, k);
            		
            		        // 情况2、区间2+区间3的个数大于等于k,目标值一定在区间2,即一定是 key
            		        else if (part2 + part3 >= k) return key;
            		
            		        // 情况3、如果不满足上面情况,则目标值一定在区间1
            		        else return FindMaxTopk(nums, l, left, k - part2 - part3);
            		    }
            		
            		    // 获取数组内的一个随机值
            		    int getRandom(vector& nums, int left, int right)
            		    {
            		        return nums[rand() % (right - left + 1) + left];
            		    }
            		};
            

            4. 库存管理Ⅲ

            题目链接 -> Leetcode -LCR 159.库存管理Ⅲ

            Leetcode -LCR 159.库存管理Ⅲ

            题目:仓库管理员以数组 stock 形式记录商品库存表,其中 stock[i] 表示对应商品库存余量。

            请返回库存余量最少的 cnt 个商品余量,返回 顺序不限。

            示例 1:

            输入:stock = [2, 5, 7, 4], cnt = 1

            输出:[2]

            示例 2:

            输入:stock = [0, 2, 3, 6], cnt = 2

            输出:[0, 2] 或[2, 0]

            提示:

            0 <= cnt <= stock.length <= 10000

            0 <= stock[i] <= 10000

            思路:与上题思路类似;在快排中,当我们把数组「分成三块」之后: [l, left] [left + 1, right - 1] [right, r] ,我们可以通过计算每一个区间内元素的「个数」,进而推断出最小的 k 个数在哪些区间里面。那么我们可以直接去「相应的区间」继续划分数组即可。

            代码如下:

            		class Solution {
            		public:
            		
            		    void my_qsort(vector& arr, int l, int r, int k)
            		    {
            		        if(l >= r) return;
            		
            		        // 根据 key 值分区间
            		        int key = getRandom(arr, l, r);
            		        int i = l, left = l - 1, right = r + 1;
            		        while(i < right)
            		        {
            		            if(arr[i] < key) swap(arr[++left], arr[i++]);
            		            else if(arr[i] == key) i++;
            		            else swap(arr[--right], arr[i]);
            		        }
            		
            		        // 根据元素个数分情况讨论
            		        // [l, left] [left + 1][right - 1] [right, r]
            		        int part1 = left - l + 1, part2 = right - left - 1;
            		        if(part1 >= k) my_qsort(arr, l, left, k);
            		        else if(part1 + part2 >= k) return;
            		        else my_qsort(arr, right, r, k - part1 - part2);
            		    }
            		
            		    // 选取基准值
            		    int getRandom(vector& arr, int left, int right)
            		    {
            		        return arr[rand() % (right - left + 1) + left];
            		    }
            		
            		    vector inventoryManagement(vector& stock, int cnt) 
            		    {
            		        srand(time(nullptr));
            		        // 快速选择算法,将数组分为三个区间,选择基准值key,比key小的元素全扔到左边
            		        my_qsort(stock, 0, stock.size() - 1, cnt);
            		        return vector(stock.begin(), stock.begin() + cnt);
            		    }
            		};
            

            5. 排序数组(归并排序)

            题目链接 -> Leetcode -912.排序数组(归并排序)

            Leetcode -912.排序数组(归并排序)

            题目:给你一个整数数组 nums,请你将该数组升序排列。

            示例 1:

            输入:nums = [5, 2, 3, 1]

            输出:[1, 2, 3, 5]

            示例 2:

            输入:nums = [5, 1, 1, 2, 0, 0]

            输出:[0, 0, 1, 1, 2, 5]

            提示:

            • 1 <= nums.length <= 5 * 10^4
            • 5 * 10^4 <= nums[i] <= 5 * 10^4

              思路:归并排序的流程充分的体现了「分而治之」的思想,大体过程分为两步:

              • 分:将数组一分为二为两部分,一直分解到数组的长度为 1 ,使整个数组的排序过程被分为「左半部分排序」 + 「右半部分排序」;
              • 治:将两个较短的「有序数组合并成⼀个长的有序数组」,一直合并到最初的长度

                代码如下:

                		class Solution
                		{
                		    vector tmp;
                		public:
                		    vector sortArray(vector& nums)
                		    {
                		        tmp.resize(nums.size());
                		        mergeSort(nums, 0, nums.size() - 1);
                		        return nums;
                		    }
                		
                		    void mergeSort(vector& nums, int left, int right)
                		    {
                		        if (left >= right) return;
                		
                		        int mid = left + (right - left) / 2;
                		
                		        // [left, mid] [mid + 1, right]
                		        mergeSort(nums, left, mid);
                		        mergeSort(nums, mid + 1, right);
                		
                		        // 合并两个区间
                		        // vector tmp(right - left + 1);  // 可以在全局定义,提高效率
                		        int i = 0, cur1 = left, cur2 = mid + 1;
                		
                		        while (cur1 <= mid && cur2 <= right)
                		            tmp[i++] = nums[cur1] <= nums[cur2] ? nums[cur1++] : nums[cur2++];
                		
                		        while (cur1 <= mid) tmp[i++] = nums[cur1++];
                		        while (cur2 <= right) tmp[i++] = nums[cur2++];
                		
                		        // 更新原数组
                		        for (int i = left; i <= right; i++)
                		            nums[i] = tmp[i - left];
                		    }
                		};
                

                6. 交易逆序对的总数

                题目链接 -> Leetcode -LCR 170.交易逆序对的总数

                Leetcode -LCR 170.交易逆序对的总数

                题目:在股票交易中,如果前一天的股价高于后一天的股价,则可以认为存在一个「交易逆序对」。请设计一个程序,输入一段时间内的股票交易记录 record,返回其中存在的「交易逆序对」总数。

                示例 1:

                输入:record = [9, 7, 5, 4, 6]

                输出:8

                解释:交易中的逆序对为(9, 7), (9, 5), (9, 4), (9, 6), (7, 5), (7, 4), (7, 6), (5, 4)。

                限制:

                • 0 <= record.length <= 50000

                  思路:用归并排序求逆序数,主要就是在归并排序的合并过程中统计出逆序对的数量,也就是在合并两个有序序列的过程中,能够快速求出逆序对的数量。

                  1. 为什么可以利用归并排序?

                  如果我们将数组从中间划分成两个部分,那么我们可以将逆序对产生的方式划分成三组:

                  • 逆序对中两个元素:全部从左数组中选择
                  • 逆序对中两个元素:全部从右数组中选择
                  • 逆序对中两个元素:一个选左数组另一个选右数组

                    根据排列组合的分类相加原理,三种情况下产生的逆序对的总和,正好等于总的逆序对数量。

                    而这个思路正好匹配归并排序的过程:

                    • 先排序左数组;
                    • 再排序右数组;
                    • 左数组和右数组合⼆为一;

                      因此,我们可以利用归并排序的过程,先求出左半数组中逆序对的数量,再求出右半数组中逆序对的数量,最后求出一个选择左边,另一个选择右边情况下逆序对的数量,三者相加即可。

                      2. 为什么要这么做?

                      在归并排序合并的过程中,我们得到的是两个有序的数组。我们是可以利用数组的有序性,快速统计出逆序对的数量,而不是将所有情况都枚举出来。

                      最核心的问题,如何在合并两个有序数组的过程中,统计出逆序对的数量?合并两个有序序列时求逆序对的方法有两种:

                      1. 快速统计出某个数前面有多少个数比它大;
                      2. 快速统计出某个数后面有多少个数比它小;

                      代码如下:

                      		class Solution
                      		{
                      		    vector tmp;
                      		public:
                      		    int reversePairs(vector& nums)
                      		    {
                      		        tmp.resize(nums.size());
                      		        return mergeSort(nums, 0, nums.size() - 1);
                      		    }
                      		
                      		    int mergeSort(vector& nums, int left, int right)
                      		    {
                      		        if (left >= right) return 0;
                      		
                      		        // [left, mid] [mid + 1, right]
                      		        int mid = left + (right - left) / 2;
                      		
                      		        // 先统计两个区间各自的逆序对个数 + 排序
                      		        int ret = 0;
                      		        ret += mergeSort(nums, left, mid);
                      		        ret += mergeSort(nums, mid + 1, right);
                      		
                      		        // 两个区间每个区间选一个进行比较,因为比较时区间已经排序好,所以当cur1中出现第一次比cur2大的数时,cur1 后面的数都可以全部统计
                      		        int cur1 = left, cur2 = mid + 1, i = 0;
                      		        while (cur1 <= mid && cur2 <= right)
                      		        {
                      		            if (nums[cur1] > nums[cur2]) ret += mid - cur1 + 1;
                      		
                      		            tmp[i++] = nums[cur1] <= nums[cur2] ? nums[cur1++] : nums[cur2++];
                      		        }
                      		
                      		        // 处理细节,还没结束的指针后的数全放入tmp中
                      		        while (cur1 <= mid) tmp[i++] = nums[cur1++];
                      		        while (cur2 <= right) tmp[i++] = nums[cur2++];
                      		
                      		        // 拷贝回原数组
                      		        for (int i = left; i <= right; i++) nums[i] = tmp[i - left];
                      		
                      		        return ret;
                      		    }
                      		};
                      

                      7. 计算右侧小于当前元素的个数

                      题目链接 -> Leetcode -315.计算右侧小于当前元素的个数

                      Leetcode -315.计算右侧小于当前元素的个数

                      题目:给你一个整数数组 nums ,按要求返回一个新数组 counts 。数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。

                      示例 1:

                      输入:nums = [5, 2, 6, 1]

                      输出:[2, 1, 1, 0]

                      解释:

                      5 的右侧有 2 个更小的元素(2 和 1)

                      2 的右侧仅有 1 个更小的元素(1)

                      6 的右侧有 1 个更小的元素(1)

                      1 的右侧有 0 个更小的元素

                      示例 2:

                      输入:nums = [-1]

                      输出:[0]

                      示例 3:

                      输入:nums = [-1, -1]

                      输出:[0, 0]

                      提示:

                      • 1 <= nums.length <= 10^5
                      • 10^4 <= nums[i] <= 10^4

                        思路:这一道题的解法与上一题的解法是类似的,但是这一道题要求的不是求总的个数,而是要返回一个数组,记录每一个元素的右边有多少个元素比自己小。

                        但是在我们归并排序的过程中,元素的下标是会跟着变化的,因此我们需要一个辅助数组,来将数组元素和对应的下标绑定在一起归并,也就是再归并元素的时候,顺势将下标也转移到对应的位置上。

                        代码如下:

                        		class Solution 
                        		{
                        		    // 将原数组的元素和下标绑定在一起,元素顺序改变时,对应的下标也跟着改变
                        		    vector tmpElement, tmpIndex;
                        		    vector index;
                        		    vector ret;
                        		    
                        		public:
                        		    vector countSmaller(vector& nums) 
                        		    {
                        		        ret.resize(nums.size());
                        		        index.resize(nums.size());
                        		
                        		        // 初始化下标
                        		        for(int i = 0; i < nums.size(); i++)
                        		            index[i] = i;
                        		
                        		        tmpElement.resize(nums.size());
                        		        tmpIndex.resize(nums.size());
                        		        
                        		        mergeSort(nums, 0, nums.size() - 1);
                        		        return ret;
                        		    }
                        		
                        		    void mergeSort(vector& nums, int left, int right)
                        		    {
                        		        if(left >= right) return;
                        		
                        		        int mid = left + (right - left) / 2;
                        		
                        		        // [left, mid] [mid + 1, right]
                        		        mergeSort(nums, left, mid);
                        		        mergeSort(nums, mid + 1, right);
                        		
                        		        int cur1 = left, cur2 = mid + 1, i = 0;
                        		        while(cur1 <= mid && cur2 <= right)
                        		        {
                        		            // index[cur1] 存的是 nums[cur1] 这个元素的原始下标
                        		            if(nums[cur1] > nums[cur2]) 
                        		                ret[index[cur1]] += right - cur2 + 1;
                        		
                        		            // 同步更新下标和元素
                        		            tmpIndex[i] = nums[cur1] > nums[cur2]? index[cur1] : index[cur2];
                        		            tmpElement[i++] = nums[cur1] > nums[cur2]? nums[cur1++] : nums[cur2++];
                        		        }
                        		
                        		        while(cur1 <= mid)
                        		        {
                        		            tmpIndex[i] = index[cur1];
                        		            tmpElement[i++] = nums[cur1++];
                        		        }
                        		        while(cur2 <= right)
                        		        {
                        		            tmpIndex[i] = index[cur2];
                        		            tmpElement[i++] = nums[cur2++];
                        		        }
                        		
                        		        // 同步拷贝下标和元素
                        		        for(int j = left; j <= right; j++)
                        		        {
                        		            index[j] = tmpIndex[j - left];
                        		            nums[j] = tmpElement[j - left];
                        		        }
                        		    }
                        		};
                        

                        8. 翻转对

                        题目链接 -> Leetcode -493.翻转对

                        Leetcode -493.翻转对

                        题目:给定一个数组 nums ,如果 i < j 且 nums[i] > 2 * nums[j] 我们就将(i, j) 称作一个重要翻转对。

                        你需要返回给定数组中的重要翻转对的数量。

                        示例 1:

                        输入: [1, 3, 2, 3, 1]

                        输出 : 2

                        示例 2 :

                        输入 : [2, 4, 3, 5, 1]

                        输出 : 3

                        注意 :

                        给定数组的长度不会超过50000。

                        输入数组中的所有数字都在32位整数的表示范围内。

                        思路:翻转对和逆序对的定义大同小异,逆序对是前面的数要大于后面的数。而翻转对是前面的⼀个数要大于后面某个数的两倍。因此,我们依旧可以用归并排序的思想来解决这个问题。

                        大思路与求逆序对的思路一样,就是利用归并排序的思想,将求整个数组的翻转对的数量,转换成三部分:左半区间翻转对的数量,右半区间翻转对的数量,一左一右选择时翻转对的数量。重点就是在合并区间过程中,如何计算出翻转对的数量。

                        例如 left = [4, 5, 6] right = [3, 4, 5] 时,如果是归并排序的话,我们需要计算 left 数组中有多少个能与 3 组成翻转对。但是我们要遍历到最后⼀个元素 6 才能确定,时间复杂度较高。因此我们需要在归并排序之前完成翻转对的统计。

                        下面以⼀个示例来模仿两个有序序列如何快速求出翻转对的过程:假定已经有两个已经有序的序列 left = [4, 5, 6] right = [1, 2, 3] ;用两个指针 cur1 和 cur2 遍历两个数组

                        • 对于任意给定的 left[cur1] 而言,我们不断地向右移动 cur2,直到 left[cur1] <= 2 * right[cur2]。此时对于 right 数组而言,cur2 之前的元素全部都可以与 left[cur1] 构成翻转对。
                        • 随后,我们再将 cur1 向右移动⼀个单位,此时 cur2 指针并不需要回退(因为 left 数组是升序的)依旧往右移动直到 left[cur1] <= 2 * right[cur2]。不断重复这样的过程,就能够求出所有左右端点分别位于两个子数组的翻转对数目。

                          由于两个指针最后都是不回退的的扫描到数组的结尾,因此两个有序序列求出翻转对的时间复杂度是 O(N).

                          综上所述,我们可以利用归并排序的过程,将求一个数组的翻转对转换成求左数组的翻转对数量 + 右数组中翻转对的数量 + 左右数组合并时翻转对的数量。

                          代码如下:

                          		class Solution 
                          		{
                          		    vector tmp;
                          		public:
                          		    int reversePairs(vector& nums) 
                          		    {      
                          		        tmp.resize(nums.size());
                          		        return mergeSort(nums, 0, nums.size() - 1);
                          		    }
                          		
                          		    int mergeSort(vector& nums, int left, int right)
                          		    {
                          		        if(left >= right) return 0;
                          		
                          		        // 1.根据中间元素划分区间
                          		        int mid = left + (right - left) / 2;
                          		
                          		        // 2. 先计算左右区间的翻转对
                          		        // [left, mid] [mid + 1, right]
                          		        int ret = 0;
                          		        ret += mergeSort(nums, left, mid);
                          		        ret += mergeSort(nums, mid + 1, right);
                          		
                          		        // 3.先利用左右区间有序的性质计算翻转对的数量
                          		        int cur1 = left, cur2 = mid + 1, i = 0;
                          		        while(cur1 <= mid)
                          		        {
                          		            while(cur2 <= right && nums[cur2] >= nums[cur1] / 2.0) cur2++;
                          		
                          		            ret += right - cur2 + 1;
                          		            cur1++;
                          		        }
                          		
                          		        // 4.合并归并区间
                          		        cur1 = left, cur2 = mid + 1;
                          		        while(cur1 <= mid && cur2 <= right)
                          		            tmp[i++] = nums[cur2] > nums[cur1]? nums[cur2++] : nums[cur1++];
                          		
                          		        while(cur1 <= mid) tmp[i++] = nums[cur1++];
                          		        while(cur2 <= right) tmp[i++] = nums[cur2++];
                          		
                          		        for(int j = left; j <= right; j++)
                          		            nums[j] = tmp[j - left];
                          		        
                          		        return ret;
                          		
                          		    }
                          		};