【算法】牛客网算法进阶班(BFPRT算法(TOP-K问题))

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/ARPOSPF/article/details/82932307

【算法】牛客网算法进阶班(BFPRT算法(TOP-K问题))


BFPRT算法(TOP-K问题)

一:题目描述

在一大堆数中求出其前k大或前k小的所有数,简称TOP-K问题。目前解决TOP-K问题最有效的算法即是BFPRT算法,又称为中位数的中位数算法,该算法由Blum、Floyd、Pratt、Rivest、Tarjan提出,最坏时间复杂度为O(n)。

首次遇到接触TOP-K问题时,第一反应是可以先对所有数据进行一次排序,然后取其前k即可,但是这么做有两个问题:
(1):快速排序的平均复杂度为O(nlogn),但最坏时间复杂度为O(n^2),不能始终保证较好的复杂度。
(2):我们只需要前k大的,而对其余不需要的数也进行了排序,浪费了大量排序时间。

除上述方法外,堆排序也是一个比较好的选择,可以维护一个大小为k的堆,时间复杂度为O(nlogk)。

那是否还存在更有效的方法呢?受到快速排序的启发,通过修改快速排序中基准元素的选取方法可以降低快速排序在最坏情况下的时间复杂度(即BFPRT算法),并且我们的目的只是求出前k,故递归的规模变小,速度也随之提高。下面来简单回顾下快速排序的过程,以升序为例:
(1):选取基准元素(首元素,尾元素或一个随机元素);
(2):以选取的基准元素为分界点,把小于基准元素的放在左边,大于基准元素的放在右边;
(3):分别对左边和右边进行递归,重复上述过程。 

二、BFPRT算法过程

BFPRT算法步骤如下:

  1. 选取基准元素;
    1. 将n个元素每5个一组,分成n/5(上界)组,最后的一个组的元素个数为n%5,有效的组数为n/5。
    2. 取出每一组的中位数,最后一个组的不用计算中位数,任意排序方法,这里的数据比较少只有5个,可以用简单的冒泡排序或是插入排序。
    3. 对于第1.2中找到的所有中位数,调用BFPRT算法求出它们的中位数,作为基准元素,设为x,偶数个中位数的情况下设定为选取中间小的一个。
  2. 以1.3中选取的基准元素作为分割点,将小于基准元素的放在左边,个数为k个,大于或等于基准元素的放在右边,个数为n-k。
  3. 判断基准元素位置i与k的大小
    1. 如果i==k,返回x;
    2. 如果i<k,在小于x的元素中递归查找第i小的元素;
    3. 如果i>k,在大于等于x的元素中递归查找第i-k小的元素。

BFPRT()调用GetPivotIndex()和Partition()来求解第k小,在这过程中,GetPivotIndex()也调用了BFPRT(),即GetPivotIndex)和BFPRT()为互递归的关系。

三:算法代码

代码如下,求前k小的数

package NowCoder2.Class01;

public class BFPRT {
    public static void main(String[] args) {
        int[] arr = {6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9};
        // sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 }
        printArray(getMinKNumsByHeap(arr, 10));//通过堆排方式得到Top-K元素
        printArray(getMinKNumsByQuick(arr, 10));//通过快排方式得到Top-K元素
        printArray(getMinKNumsByBFPRT(arr, 10));//通过BFPRT算法得到Top-K元素
    }

    /**
     * 堆排解法,时间复杂度为O(N*logK)
     *
     * @param arr
     * @param i
     * @return
     */
    private static int[] getMinKNumsByHeap(int[] arr, int k) {
        if (k < 1 || k > arr.length) {
            return arr;
        }
        int[] kHeap = new int[k];
        for (int i = 0; i != k; i++) {
            heapInsert(kHeap, arr[i], i);
        }
        for (int i = k; i != arr.length; i++) {
            if (arr[i] < kHeap[0]) {
                kHeap[0] = arr[i];
                heapify(kHeap, 0, k);
            }
        }
        return kHeap;
    }

    private static void heapInsert(int[] arr, int value, int index) {
        arr[index] = value;
        while (index != 0) {
            int parent = (index - 1) / 2;
            if (arr[parent] < arr[index]) {
                swap(arr, parent, index);
                index = parent;
            } else {
                break;
            }
        }
    }

    private static void heapify(int[] arr, int index, int heapSize) {
        int left = index * 2 + 1;
        int right = index * 2 + 2;
        int largest = index;
        while (left < heapSize) {
            if (arr[left] > arr[index]) {
                largest = left;
            }
            if (right < heapSize && arr[right] > arr[largest]) {
                largest = right;
            }
            if (largest != index) {
                swap(arr, largest, index);
            } else {
                break;
            }
            index = largest;
            left = index * 2 + 1;
            right = index * 2 + 2;
        }
    }

    /**
     * 通过快排的方式,时间复杂度为O(N)
     *
     * @param arr
     * @param k
     * @return
     */
    private static int[] getMinKNumsByQuick(int[] arr, int k) {
        if (arr != null && arr.length > 0) {
            int low = 0;
            int high = arr.length - 1;
            int index = partition(arr, low, high);
            //不断调整分治思想,直到position=k-1
            while (index != k - 1) {
                //大了,往前调整
                if (index > k - 1) {
                    high = index - 1;
                    index = partition(arr, low, high);
                }
                //小了,往后调整
                if (index < k - 1) {
                    low = index + 1;
                    index = partition(arr, low, high);
                }
            }
        }
        int[] res = new int[k];
        for (int i = 0; i < res.length; i++) {
            res[i] = arr[i];
        }
        return res;
    }

    private static int partition(int[] arr, int low, int high) {
        if (arr != null && low < high) {
            int flag = arr[low];
            while (low < high) {
                while (low < high && arr[high] >= flag) {
                    high--;
                }
                arr[low] = arr[high];
                while (low < high && arr[low] <= flag) {
                    low++;
                }
                arr[high] = arr[low];
            }
            arr[low] = flag;
            return low;
        }
        return 0;
    }

    /**
     * 通过BRPRT算法获得Top-K问题的解,时间复杂度为O(N)
     *
     * @param arr
     * @param k
     */
    private static int[] getMinKNumsByBFPRT(int[] arr, int k) {
        if (k < 1 || k > arr.length) {
            return arr;
        }
        int minKth = getMinKthByBFPRT(arr, k);
        int[] res = new int[k];
        int index = 0;
        for (int i = 0; i < arr.length; i++) {
            if (arr[i] < minKth) {
                res[index++] = arr[i];
            }
        }
        for (; index < res.length; index++) {
            res[index] = minKth;
        }
        return res;
    }

    private static int getMinKthByBFPRT(int[] arr, int k) {
        int[] copyArr = copyArray(arr);
        return select(copyArr, 0, copyArr.length - 1, k - 1);
    }

    private static int[] copyArray(int[] arr) {
        int[] res = new int[arr.length];
        for (int i = 0; i < res.length; i++) {
            res[i] = arr[i];
        }
        return res;
    }

    private static int select(int[] arr, int begin, int end, int i) {
        if (begin == end) {
            return arr[begin];
        }
        int pivot = medianOfMedians(arr, begin, end);
        int[] pivotRange = partition(arr, begin, end, pivot);
        if (i >= pivotRange[0] && i <= pivotRange[1]) {
            return arr[i];
        } else if (i < pivotRange[0]) {
            return select(arr, begin, pivotRange[0] - 1, i);
        } else {
            return select(arr, pivotRange[1] + 1, end, i);
        }
    }

    private static int medianOfMedians(int[] arr, int begin, int end) {
        int num = end - begin + 1;
        int offset = num % 5 == 0 ? 0 : 1;
        int[] mArr = new int[num / 5 + offset];
        for (int i = 0; i < mArr.length; i++) {
            int beginI = begin + i * 5;
            int endI = beginI + 4;
            mArr[i] = getMedian(arr, beginI, Math.min(end, endI));
        }
        return select(mArr, 0, mArr.length - 1, mArr.length / 2);
    }

    private static int[] partition(int[] arr, int begin, int end, int pivotValue) {
        int small = begin - 1;
        int cur = begin;
        int big = end + 1;
        while (cur != big) {
            if (arr[cur] < pivotValue) {
                swap(arr, ++small, cur++);
            } else if (arr[cur] > pivotValue) {
                swap(arr, cur, --big);
            } else {
                cur++;
            }
        }
        int[] range = new int[2];
        range[0] = small + 1;
        range[1] = big - 1;
        return range;
    }

    private static int getMedian(int[] arr, int begin, int end) {
        insertionSort(arr, begin, end);
        int sum = end + begin;
        int mid = (sum / 2) + (sum % 2);
        return arr[mid];
    }

    private static void insertionSort(int[] arr, int begin, int end) {
        for (int i = begin + 1; i != end + 1; i++) {
            for (int j = i; j != begin; j--) {
                if (arr[j - 1] > arr[j]) {
                    swap(arr, j - 1, j);
                } else {
                    break;
                }
            }
        }
    }

    /**
     * 公共方法,交换数据和打印数据
     * @param arr
     * @param i
     * @param j
     */
    private static void swap(int[] arr, int i, int j) {
        int temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }

    private static void printArray(int[] arr) {
        for (int i = 0; i < arr.length; i++) {
            System.out.print(arr[i] + " ");
        }
        System.out.println();
    }
}

相关求解方法,如堆排解法和块排解法也一并列在上述代码中。

四:其他解法

堆排解法

用堆排来解决Top K的思路很直接。堆排利用的大(小)顶堆所有子节点元素都比父节点小(大)的性质来实现的,既然一个大顶堆的顶是最大的元素,那我们要找最小的K个元素,是不是可以先建立一个包含K个元素的堆,然后遍历集合,如果集合的元素比堆顶元素小(说明它目前应该在K个最小之列),那就用该元素来替换堆顶元素,同时维护该堆的性质,那在遍历结束的时候,堆中包含的K个元素是不是就是我们要找的最小的K个元素? 

速记口诀:最小的K个用最大堆,最大的K个用最小堆。

时间复杂度:O(n*logK)

适用场景:实现的过程中,先用前K个数建立了一个堆,然后遍历数组来维护这个堆。这种做法带来了三个好处:

(1)不会改变数据的输入顺序(按顺序读的);

(2)不会占用太多的内存空间(事实上,一次只读入一个数,内存只要求能容纳前K个数即可);

(3)由于(2),决定了它特别适合处理海量数据。

快排解法

用快排的思想来解Top K问题,必然要运用到”分治”。

与快排相比,两者唯一的不同是在对”分治”结果的使用上。我们知道,分治函数会返回一个position,在position左边的数都比第position个数小,在position右边的数都比第position大。我们不妨不断调用分治函数,直到它输出的position = K-1,此时position前面的K个数(0到K-1)就是要找的前K个数。

时间复杂度:O(n)

适用场景:对照着堆排的解法来看,partition函数会不断地交换元素的位置,所以它肯定会改变数据输入的顺序;既然要交换元素的位置,那么所有元素必须要读到内存空间中,所以它会占用比较大的空间,至少能容纳整个数组;数据越多,占用的空间必然越大,海量数据处理起来相对吃力

但是,它的时间复杂度很低,意味着数据量不大时,效率极高。

猜你喜欢

转载自blog.csdn.net/ARPOSPF/article/details/82932307