赞
踩
快速选择排序 Quick select 解决Top K 问题
topK问题是实际应用中涉及面较广的一个抽象问题,譬如:从20亿个数字的文本中,找出最大的前100个。
引入
看到这个问题你可能自然而然的想到了排序,无论是平均时间复杂度为 O(NlogN)
的快排,还是时间复杂度为 O(NlogN)
的归并排序和堆排序都可以,但问题是如果 N 很大呢?
有没有一种方法不需要对所有元素进行排序呢?
且看使用冒泡排序或者选择排序,那么时间复杂度就是 O(Nk)
了,和刚刚提到的方法哪个更优呢,这取决于 logN
和k
的大小。
k-(pivot-left+1)
即可tip:
注意基准值所在位置 pivot_idx 不要再进入递归中
const swap=(arr,a,b)=>{ let temp=arr[a]; arr[a]=arr[b]; arr[b]=temp; }; const partition=(arr,k,left,right)=>{ let pivot=left,lessThan=left; if(left===right) return left; for(let i=left;i<=right;i++){ if(arr[i]<arr[pivot]){ lessThan++; swap(arr,lessThan,i); } } swap(arr,lessThan,pivot); return lessThan; }; const quickSelect=(arr,k,left,right)=>{ let idx=partition(arr,k,left,right); // 一定要注意:idx已经被检查过所以不能再将其加入重新进行检查 if(right-idx+1>k){ return quickSelect(arr,k,idx+1,right); }else if(right-idx+1===k){ return arr[idx]; }else{ return quickSelect(arr,k-(right-idx+1),left,idx-1); } }; const topK=(arr,k)=>{ return quickSelect(arr,k,0,arr.length-1); };
quickSelect
方法 与Quick sort不同的是,Quick select只考虑所寻找的目标所在的那一部分子数组,而非像Quick sort一样分别再对两边进行分 割。正是因为如此,Quick select将平均时间复杂度从O(nlogn)
降到了O(n)
,但与此同时,QuickSelect与QuickSort一样,是一个不稳定的算法;pivot选取直接影响了算法的好坏,worst case下的时间复杂度达到了O(n^2)
关于二分查找法的详解可以参考我的另一篇文章:详解二分查找法
假如 N 个数中最大值是 V m a x V_{max} Vmax 最小值是 V m i n V_{min} Vmin,那么要求的第 k 大的值一定在 [ V m a x V_{max} Vmax , V m i n V_{min} Vmin] ,我们可以利用二分查找对比原数组中 > = m i d >=mid >=mid 的值的个数,利用求 target 右边界的方式求得答案。
const topK2=(arr,k)=>{ let low=Number.MAX_SAFE_INTEGER,high=-Number.MAX_SAFE_INTEGER; for(let i=0;i<arr.length;i++){ if(arr[i]<low){ low=arr[i]; } if(arr[i]>high){ high=arr[i]; } } // count用来计算数组中大于等于 mid 的个数 let cn; while(low<=high){ let mid=Math.floor((low+high)/2); cn=count(arr,mid); // 这里其实是求 target 右边界 if(cn>=k){ // mid is too small low=mid+1; }else{ // mid is too big // 这里可以做层优化,缩减 k 的值,让 count 在 mid 和 high 中寻找大于等于 mid 的数目 high=mid-1; } } return right; }; // ========= 扩展 ========= /** * 如何改造成求最小的第k个值呢,使用求左边界的方式 * @param arr * @param k * @returns {number} */ const topK3=(arr,k)=>{ const count=(arr,target)=>{ let count=0; for(let i=0;i<arr.length;i++){ if(arr[i]<=target){ count++; } } return count; }; let low=Number.MAX_SAFE_INTEGER,high=-Number.MAX_SAFE_INTEGER; for(let i=0;i<arr.length;i++){ if(arr[i]<low){ low=arr[i]; } if(arr[i]>high){ high=arr[i]; } } // count() 用来计算原数组中小于等于 target 的数量 let cn; while(low<=high){ let mid=Math.floor((low+high)/2); cn=count(arr,mid); if(cn>=k){ // mid is too big high=mid-1; }else{ // mid is too small low=mid+1; } } return low; };
O(Nlog(max_val-min_val))
首先这个题就是典型的topk问题,但是用构建最大堆的方式并不能通过,因此这边可以考虑用二分法的方式解决,既然采用二分法套用上面模板即可,但是难点在于如何遍历所有找出来小于mid的所有距离对的个数呢?(遍历所有显然不现实)
此处查找count数目用到了双指针:
arr[right]-arr[left]<=mid
right-left
的值即为小于mid的所有距离对个数。(至于为何是right-left
可以通过自己举例验证比如1到4之间有多少距离对:3+2+1
)const count=(arr,target)=>{ let res=0,left=0; for(let right=0;right<arr.length;right++){ while(arr[right]-arr[left]>target)left++; res+=right-left; } return res; }; const smallestDistancePair1=(arr,k)=>{ arr.sort((a,b)=>a-b); let min=0,max=arr[arr.length-1]-arr[0]; while(min<=max){ let mid=Math.floor((min+max)/2); let cn=count(arr,mid); if(cn>=k){ // mid过大,===k的情况时有可能取的mid也是一个大于目标值的数 max=mid-1; }else{ min=mid+1; } } return min; };
解题思路可以参考我的题解二分法解决or最小堆解决,也可以参考下文
具体可以参考我的另一篇文章:数据结构javascript描述中对于heap
的总结。
假设我们已经构造了一个存有 k 个元素的小根堆,根节点元素就是第 k 大元素也就是后 k 个元素中最小的元素,假如后面继续遍历,有个元素 a 比 根节点还小,假设依然成立,如果比它大那么假设就不成立了,此时将根节点换成 a 并重新 heapify,如此直到遍历完所有元素,得到真正的后 k 个元素组成的小根堆。
Tips:
Topk 问题就用小根堆解决,Lowk 问题就用大根堆解决。
// 注意如果求的不是 topK 而是 lowK 则是用 大根堆
import Heap from './algorithm/Heap/MinHeap';
const topK=(arr,k)=>{
let h=new Heap();
for(let i=0;i<k;i++){
h.insert(arr[i]);
}
for(let i=k;i<arr.length;i++){
if (arr[i]>h.data[0]){
h.deleting();
h.insert(arr[i]);
}
}
return h.data[0];
};
O(NlogK)
其中N
为数组全部长度,K
即为要求的K
(因为堆中元素的数目永远是K
)O(K)
要使用哪种方法解决问题需要根据实际题目做出选择:
最佳方法:二分法。其他方法也可以,但复杂度过高。
其他技巧:灵活应用双指针解决问题
class Solution: def smallestDistancePair(self, nums: List[int], k: int) -> int: """ 优化:count 时因为数组是有序的,除了二分查找还可以使用双指针 时间复杂度:O(n*logD) 其中 D=max(nums)-min(nums) 空间复杂度:O(logn) 主要是排序占据的 :param nums: :param k: :return: """ nums.sort() length = len(nums) def count(d: int): """ 寻找距离小于等于 d 的个数 :param d: :return: """ res = 0 j = 0 for i in range(length): while nums[i] - nums[j] > d: j += 1 res += i - j return res def my_bisect_left(l: int, r: int, target: int, f: Callable) -> int: """ 返回 target 位于原数组最左侧可以插入的位置 :param l: :param r: :param target: :param f: :return: """ while l <= r: mid = (l + r) // 2 if f(mid) < target: l = mid + 1 else: r = mid - 1 return l max_val = nums[-1] - nums[0] # Tip: 因为是要求第 k 个,所以要按小于等于 k 来算 return my_bisect_left(0, max_val, k, count)
可以使用二分法解决,但题目找前 k 个而不是第 k 个,因此需要特殊处理一下
根据题目规律,利用小根堆解决问题
class Solution: def kSmallestPairs_bisect(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]: """ 首先读懂题意,从 nums1 nums2 中各取一值,取和排序为前 k 个的索引对。 方法一:穷举所有可能的组合,排序取前 k 个即可,时间复杂度:O(mnlog(mn)) 方法二:已知最小值,最大值,可以用二分法求 lowk 的方式, 时间复杂度:O((m+n)log(max-min)) 空间复杂度:O(k) :param nums1: :param nums2: :param k: :return: """ m, n = len(nums1), len(nums2) min_val, max_val = nums1[0] + nums2[0], nums1[-1] + nums2[-1] def count(target: int) -> int: # [1,7,11] # [2,4,6] 3 # 3+17 <=10 4 # 3+9 <=6 2 # 7+9 <=7 3 j = n - 1 count = 0 for i in range(m): while j >= 0 and nums1[i] + nums2[j] > target: j -= 1 count += j + 1 return count while min_val <= max_val: mid = (min_val + max_val) // 2 if count(mid) < k: min_val = mid + 1 else: max_val = mid - 1 # min_val 即为所求 lowk # print(min_val) # 但由于 min_val 不仅仅可能是 lowk 也可能是 lowk+1 lowk+2 针对这种情景就需要处理 equal 的场景 res, equal = [], [] j = n - 1 for i in range(m): while j >= 0 and nums1[i] + nums2[j] > min_val: j -= 1 for x in range(j + 1): if nums1[i] + nums2[x] == min_val: equal.append([nums1[i], nums2[x]]) else: res.append([nums1[i], nums2[x]]) if len(res) < k: res.extend(equal[0:k - len(res)]) return res def heapify(self, nums: List[List[int]], length: int, i: int): """ 从 i 开始构建最小堆 :param i: :param length: :param nums: :return: """ l, r = i * 2 + 1, i * 2 + 2 min_idx = i if l < length and nums[l][2] < nums[min_idx][2]: min_idx = l if r < length and nums[r][2] < nums[min_idx][2]: min_idx = r if min_idx != i: nums[i], nums[min_idx] = nums[min_idx], nums[i] self.heapify(nums, length, min_idx) def build_min_heap(self, nums: List[List[int]]): length = len(nums) # 最后一个非叶子节点所在索引 idx = length // 2 - 1 for i in range(idx, -1, -1): self.heapify(nums, length, i) def heap_extract_min(self, nums: List[List[int]]) -> List[int]: min_val = nums[0] length = len(nums) nums[0], nums[length - 1] = nums[length - 1], nums[0] nums.pop() self.heapify(nums, length - 1, 0) return min_val def heap_decrease(self, nums: List[List[int]], i: int, val: List[int]): """ 索引 i 处 修改为 val :param nums: :param i: :param val: :return: """ # 根节点 def get_root(idx: int) -> int: return (idx + 1) // 2 - 1 nums[i] = val while get_root(i) >= 0 and nums[get_root(i)][2] > nums[i][2]: nums[get_root(i)], nums[i] = nums[i], nums[get_root(i)] i = get_root(i) def heap_push(self, nums: List[List[int]], val: List[int]): nums.append(val) length = len(nums) self.heap_decrease(nums, length - 1, val) def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]: """ 方法三:先选 k 个值构建大根堆,然后遍历所有元素将其与 heap[0] 对比,得到答案,但这种方式需要遍历 mn 个元素,如果 nums1 nums2 数组较大,显然效率较低。但我们可以根据题目的规律构建小根堆来解决问题: 已知最小的是 (0,0), 下一个就是待比较的就是 (0,1) 和 (1,0) 假如下一个是 (0,1) 那么下一个要比较的是 (1,0) (1,1) (0,2) => (1,0) (0,2) 假如下一个是 (1,0) 那么下一个要比较的是 (0,1) (1,1) (2,0) => (0,1) (2,0) 假如下一个是 (1,0) 那么下一个要比较的是 (0,2) (1,1) (2,0) 其实就是上次比较的数,加上新的 (a+1,b) (a,b+1),但是每次都这么增加其实带有重复的情况,比如 (0,0) (0,1) (1,0) 选 (0,1) 则 (1,0) 再加入 (1,1) (0,2) 选 (1,0) 则 (1,1) (0,2) 再加入 (1,1) (2,0) 此时重复了 (1,1) 如果一开始我们就有 (0,0) (1,0) (2,0)...(k-1,0), 找到最小值 (a,b) 之后每次都添加 (a,b+1) 则变成了 (0,0) (1,0)... 再加入 (0,1) 选 (0,1) 则 (1,0) (2,0) ... 再加入 (0,2) 选 (1,0) 则 (2,0) (0,2) ... 再加入 (1,1) 满足需求 建堆:O(N) extract_min: O(logN) heap_push: O(logN) 时间复杂度:O(klogk) 空间复杂度:O(k) :param nums1: :param nums2: :param k: :return: """ m, n = len(nums1), len(nums2) nums = [[i, 0, nums1[i] + nums2[0]] for i in range(min(k, m))] self.build_min_heap(nums) res = [] while nums and len(res) < k: [i, j, _] = self.heap_extract_min(nums) res.append([nums1[i], nums2[j]]) if j + 1 < n: self.heap_push(nums, [i, j + 1, nums1[i] + nums2[j + 1]]) return res
最佳方法是使用二分法获得 lowk
其他技巧:z 形搜索(主要针对横纵均有序的矩阵,利用第一行最后一个元素为基准进行搜索的算法)
class Solution: def kthSmallest_0(self, matrix: List[List[int]], k: int) -> int: """ 首先读懂题意,已知矩阵中每一行和列是增序的,求矩阵中所有元素的第 k 小的元素 方法一:直接对 n*n 个数字进行排序,时间复杂度 n*n*log(n*n) = O(2n^2logn) 方法二:quick select 期望时间复杂度为 O(n^2) 方法三:由于矩阵中元素其实是有序的,可以考虑使用二分法,得到 max_val, min_val,得到 mid 求小于等于 mid 的值的数量 时间复杂度:O(Nlog(max_val-min_val)) 空间复杂度:O(1) :param matrix: :param k: :return: """ n = len(matrix) min_val, max_val = matrix[0][0], matrix[n - 1][n - 1] def count(target: int, n: int) -> int: """ 寻找小于等于 target 的个数 eg: 在 [[1,5,9],[10,11,13],[12,13,15]] 中寻找小于等于 8 的个数 1 5 9 10 11 13 12 13 15 如果使用 z 字形搜索,l=0,r=n-1,选这个点作为对比点,比它大的值肯定在下一行,比它小的则可以继续在该行左移,其他的所有点都类似, 因此可以利用 z 字形搜索找到要找的答案。 时间复杂度:O(2n) 即 O(n) 空间复杂度:O(1) <= 12 的也利用 z 字形搜索,有 3+2+1=6 <=14,有 3+3+2=8 此时要缩减一下 <=13,有 3+3+2=8 此时再缩减一下,13+12//2=12 <=12,有 3+2+1=6 left=mid+1,变成 13,得到结果 :param target: :return: """ left, right = 0, n - 1 res = 0 while left < n and right >= 0: if matrix[left][right] <= target: res += right + 1 left += 1 else: right -= 1 return res while min_val <= max_val: mid = (min_val + max_val) // 2 # 注意这里是找 bisect_left if count(mid, n) < k: min_val = mid + 1 else: max_val = mid - 1 return min_val def heapify(self, nums: List[int], length: int, idx: int): """ max-heap heapify :param nums: :param length: :param idx: :return: """ max_idx = idx l = idx * 2 + 1 r = idx * 2 + 2 if l < length and nums[l] > nums[max_idx]: max_idx = l if r < length and nums[r] > nums[max_idx]: max_idx = r if max_idx != idx: nums[max_idx], nums[idx] = nums[idx], nums[max_idx] self.heapify(nums, length, max_idx) def build_max_heap(self, nums: List[int]): length = len(nums) last = length // 2 - 1 for i in range(last, -1, -1): self.heapify(nums, length, i) def heap_change_val(self, nums: List[int], idx: int, val: int): nums[idx] = val if idx == 0: self.heapify(nums, len(nums), 0) def kthSmallest(self, matrix: List[List[int]], k: int) -> int: """ 方法四:构建大根堆,审查所有元素,最后得到的根即为第 k 小的元素,时间复杂度 O(N^2logk) :param matrix: :param k: :return: """ n = len(matrix) length = min(k, n * n) nums = [] total = 0 stop_i, stop_j = 0, 0 for i in range(n): for j in range(n): if total >= length: break nums.append(matrix[i][j]) stop_i, stop_j = i, j total += 1 # 构建 max-heap 时间复杂度 O(k) self.build_max_heap(nums) # print('====>', nums, stop_i, stop_j) for j in range(stop_j+1, n): if nums[0] > matrix[stop_i][j]: self.heap_change_val(nums, 0, matrix[stop_i][j]) for i in range(stop_i + 1, n): for j in range(n): if nums[0] > matrix[i][j]: self.heap_change_val(nums, 0, matrix[i][j]) return nums[0]
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。