数据结构-BFPRT算法

背景

在一大堆数中求其前k大或前k小的问题,简称TOP-K问题。而目前解决TOP-K问题最有效的算法即是BFPRT算法,其又称为中位数的中位数算法,该算法由Blum、Floyd、Pratt、Rivest、Tarjan提出,最坏时间复杂度为$O(n)$。

在首次接触TOP-K问题时,我们的第一反应就是可以先对所有数据进行一次排序,然后取其前k即可,但是这么做有两个问题:
(1)快速排序的平均复杂度为$O(nlogn)$,但最坏时间复杂度为$O(n^2)$,不能始终保证较好的复杂度。
(2)我们只需要前k大的,而对其余不需要的数也进行了排序,浪费了大量排序时间。

除这种方法之外,堆排序也是一个比较好的选择,可以维护一个大小为k的堆,时间复杂度为$O(nlogk)$。

那是否还存在更有效的方法呢?受到快速排序的启发,通过修改快速排序中主元的选取方法可以降低快速排序在最坏情况下的时间复杂度(即BFPRT算法),并且我们的目的只是求出前k,故递归的规模变小,速度也随之提高。下面来简单回顾下快速排序的过程,以升序为例:
(1)选取主元(首元素,尾元素或一个随机元素);
(2)以选取的主元为分界点,把小于主元的放在左边,大于主元的放在右边;
(3)分别对左边和右边进行递归,重复上述过程。

快速排序代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution(object):
def main(self, nums, k):

self.quickSort(nums, 0, len(nums)-1)
return nums[k]

def quickSort(self, nums, left, right):
if left >= right:
return

povit = self.partition(nums, left, right)
self.quickSort(nums, left, povit-1)
self.quickSort(nums, povit+1, right)

def partition(self, nums, left, right):
temp = nums[left]
while left < right:
while left < right and nums[right] >= temp:
right -= 1
self.swap(nums, left, right)
while left < right and nums[left] <= temp:
left += 1
self.swap(nums, left, right)
return left

def swap(self, nums, left, right):
temp = nums[left]
nums[left] = nums[right]
nums[right] = temp


if __name__ == '__main__':
res = Solution().main([3,4,1,2,5,7,23,4,1,5], 4)
print(res)

BFPRT算法

BFPRT算法步骤如下:
(1)选取主元;
  (1.1)将n个元素划分为$n/5$个组,每组5个元素,最后不足5个的为一组;
  (1.2)使用插入排序找到$n/5$个组中每一组的中位数;
  (1.3)对于(1.2)中找到的所有中位数,调用BFPRT算法求出它们的中位数,作为主元;
(2)以(1.3)选取的主元为分界点,把小于主元的放在左边,等于主元的放中间,大于主元的放在右边,返回等于主元的左右边界;
(3)判断主元的位置与k的大小,有选择的对左边或右边递归。

下面为代码实现,其所求为前K小的数:
注意:
在partition()中,需要首先将主元和最左边的元素进行交换。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class Solution(object):
def main(self, nums, k):
"""
:type nums: List[int]
:type k: int
:rtype: int
"""
return self.bfprt(nums, 0, len(nums)-1, k-1)

def bfprt(self, nums, left, right, k):
if left >= right:
return nums[left]

povitValue = self.medianOfMedians(nums, left, right)
povitRange = self.partition(nums, left, right, povitValue) # 等于区域的左右边界
if k >= povitRange[0] and k <= povitRange[1]: # k 在左右边界内 直接返回
return nums[k]
elif k < povitRange[0]: # k在左边
return self.bfprt(nums, left, povitRange[0]-1, k)
else: # k在右边
return self.bfprt(nums, povitRange[1]+1, right, k)


def medianOfMedians(self, nums, left, right):
len_ = right-left + 1
offset = 0 if len_%5 == 0 else 1
mArr = [0]*(len_//5+offset)
for i in range(len(mArr)):
left_I = left + i*5
right_I = left_I + 4
mArr[i] = self.getMedian(nums, left_I, min(right, right_I))

return self.bfprt(mArr, 0, len(mArr)-1, len(mArr)//2)

def getMedian(self, nums, left, right):
for i in range(left+1, right+1):
j = i
temp = nums[i]
while j-1 >= left and nums[j-1] > temp:
nums[j] = nums[j-1]
j -= 1
nums[j] = temp
return nums[(left+right)//2 + (left+right)%2]

def partition(self, nums, left, right, povitValue):
index = nums.index(povitValue)
self.swap(nums, left, index)
while left < right:
while left < right and nums[right] >= povitValue:
right -= 1
self.swap(nums, left, right)
while left < right and nums[left] <= povitValue:
left += 1
self.swap(nums, left, right)
return [left, right]

def swap(self, nums, left, right):
temp = nums[left]
nums[left] = nums[right]
nums[right] = temp

if __name__ == '__main__':
res = Solution().main([7,6,5,4,3,2,1], 5)
print(res)

时间复杂度分析

BFPRT算法在最坏情况下的时间复杂度是$O(n)$,下面予以证明。令$T(n)$为所求的时间复杂度,则有:

$T(n)≤T(\frac{n}{5})+T(\frac{7n}{10})+c⋅n$ (c为一个正常数)

其中:

  • $T(\frac{n}{5})$来自medianOfMedians(),n个元素,5个一组,共有$\frac{n}{5}$个中位数;
  • $T(\frac{7n}{10})$来自bfprt(),在$\frac{n}{5}$个中位数中,主元x大于其中 $\frac{1}{2}\frac{n}{5} = \frac{n}{10}$的中位数,而每个中位数在其本来的5个数的小组中又大于或等于其中的3个数,所以主元x至少大于所有数中的 $\frac{n}{10}3=\frac{3n}{10}$个。即划分之后,任意一边的长度至少为$\frac{3}{10}$,在最坏情况下,每次选择都选到了$\frac{7}{10}$的那一部分。
  • $c⋅n$来自其它操作,比如getMedian()中的插入排序,以及medianOfMedians()和partition()里所需的一些额外操作。

设$T(n)=t \cdot n$,其中t为未知,它可以是一个正常数,也可以是一个关于n的函数,代入上式:

$t \cdot n \leq \frac{t \cdot n}{5}+\frac{7 t \cdot n}{10}+c \cdot n$ (两边消去n)
$t \leq \frac{t}{5}+\frac{7 t}{10}+c$ (再化简)
$t \leq \frac{9t}{10}+ c$ (c为一个正常数)
$t \leq 10 c$

其中c为一个正常数,故t也是一个正常数,即$T(n) \leq 10 c \cdot n$ (c为一个正常数),因此$T(n)=O(n)$,至此证明结束。

例题: leetcode 215. 数组中的第K个最大元素

题目描述

在未排序的数组中找到第 k 个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

  • 示例 1:

    输入: [3,2,1,5,6,4] 和 k = 2
    输出: 5

  • 示例 2:

    输入: [3,2,3,1,2,4,5,5,6] 和 k = 4
    输出: 4

  • 说明:

    你可以假设 k 总是有效的,且 1 ≤ k ≤ 数组的长度。

解题思路

可以使用快排,堆排,归并排序等算法进行解题,但平均复杂度为$O(nlogn)$。

以下代码为使用BFPRT算法解决,复杂度为$O(n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class Solution(object):
def findKthLargest(self, nums, k):
"""
:type nums: List[int]
:type k: int
:rtype: int
"""
if len(nums) < k or k <= 0:
return 0

return self.bfprt(nums, 0, len(nums)-1, k-1)

def bfprt(self, nums, left, right, k):
if left >= right:
return nums[left]

povitValue = self.medianOfMedians(nums, left, right)
povitRange = self.partition(nums, left, right, povitValue) # 等于区域的左右边界
if k >= povitRange[0] and k <= povitRange[1]: # k 在左右边界内 直接返回
return nums[k]
elif k < povitRange[0]: # k在左边
return self.bfprt(nums, left, povitRange[0]-1, k)
else: # k在右边
return self.bfprt(nums, povitRange[1]+1, right, k)

def medianOfMedians(self, nums, left, right):
len_ = right-left + 1
offset = 0 if len_%5 == 0 else 1
mArr = [0]*(len_//5+offset)
for i in range(len(mArr)):
left_I = left + i*5
right_I = left_I + 4
mArr[i] = self.getMedian(nums, left_I, min(right, right_I))
return self.bfprt(mArr, 0, len(mArr)-1, len(mArr)//2)

def getMedian(self, nums, left, right):
self.insertSort(nums, left, right)
return nums[(left+right)//2 + (left+right)%2]

def partition(self, nums, left, right, povitValue):
index = nums.index(povitValue)
self.swap(nums, left, index)
while left < right:
while left < right and nums[right] <= povitValue:
right -= 1
self.swap(nums, left, right)
while left < right and nums[left] >= povitValue:
left += 1
self.swap(nums, left, right)
return [left, right]

def swap(self, nums, left, right):
temp = nums[left]
nums[left] = nums[right]
nums[right] = temp

def insertSort(self, nums, left, right):
for i in range(left+1, right+1):
j = i
temp = nums[i]
while j-1 >= left and nums[j-1] < temp:
nums[j] = nums[j-1]
j -= 1
nums[j] = temp

参考文献:
https://en.wikipedia.org/wiki/Median_of_medians
https://blog.csdn.net/laojiu_/article/details/54986553

Donate comment here
------------The End------------
0%