[LeetCode] 分治之 Median of Two Sorted Arrays 课后题算法实现 Hard

写在前面

这道题目是我曾经在LeetCode上就见到过的,是很少的限定了时间复杂度的题目之一。这道题目的难度可以说是蛮高的,Hard 22.1%的通过率,可以说是在leetcode的所有题目来说也是非常难的了。当时学习分治法的时候曾经看过题目,也看了discuss但是并没有很好的思路,在O( log(m+n) )的时间复杂度来说还是算比较难的。后来在老师讲课后习题的时候,恰好在算法概论第二章的一题中提到了这个类似的题目,要求找到两个有序数组中第K大的数,并且要求的时间复杂度是O( log(m)+log(n) ),时间复杂度相比起来会更低。当时我在课上提出了一个符合题目时间复杂度的思路,而且也基本可行,时间复杂度也符合,但是因为终止条件什么的都还没有提出,所以实现的话还需要另外丰富这个算法。于是课后就专门找了这题来实现。


题目

There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
Example 1 :

nums1 = [1, 3]
nums2 = [2]

The median is 2.0

Example 2 :

nums1 = [1, 2]
nums2 = [3, 4]

The median is (2 + 3)/2 = 2.5

分析

当时我在课上提出的算法是针对于任意给定的k(k < nums1.size() + nums2.size())的,对于这题来说,k是中位数,也就是k = (nusm1.size() + nums2.size())/2,这只是一种特殊情况,下面我将对于任意给定的k来进行设计算法。由于时间比较紧张,可能有更简单的,但是我没有考虑那么多的优化。我的思路是递归的,但是由于是尾递归,可以在实现的时候转成循环,这种转换是简单的。

算法流程

首先将两个array一个命名为A,另一个为B,A的size是m,B的size是n

  1. 如果A或B是空的,那么直接返回另一个数组的中位数
  2. 如果发现其中某个数组的最大值小于另外一个数组的最小值,那么直接求得中位数
  3. 如果发现其中某个数组的size == 1,那么直接将其二分查找插入另一个数组,并求得中位数
  4. 比较A[m/2]和B[n/2]的大小,假设A[m/2] >= B[n/2]
    4.1. 如果k<=(m+n)/2,则去掉A的后半部分,并回到步骤2
    4.2 如果k>(m+n)/2,则去掉B的前半部分,并将k赋值为k-n/2,并回到步骤2
  5. 比较A[m/2]和B[n/2]的大小,假设A[m/2] <= B[n/2]
    5.1 如果k<=(m+n)/2,则去掉B的后半部分,并回到步骤2
    5.2 如果k>(m+n)/2,则去掉A的前半部分,并将k赋值为k-m/2,并回到步骤2

实际题目会因为奇偶数,导致中位数的求法需要另外考虑,但是这并不要紧,只是稍微麻烦了一些而已。

时间复杂度分析

时间复杂度之前已经说过了,是O( log(n)+log(m) ),这是易于发现的,因为每一轮递归都至少有一个数组的求解范围会缩短一半,除非遇到终止条件。

代码

这里给出了递归和非递归的两种实现。实现的时候千万要注意的一点就在于各种边界的+1,-1,以及不要忘记我们只是缩小了下标的范围,并没有真实的缩小两个数组,因此在比较的时候必须注意什么时候需要加上al、bl,由于一直是静态debug,所以这种错误有时候会因为思路是跟着算法走的,实际上实现的时候就出现了纰漏,最后花了很久时间在找这两个错误上。

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        a = nums1;
        b = nums2;
        isOdd = (nums1.size() + nums2.size()) % 2;
        k = (nums1.size() + nums2.size()) >> 1;
        // consider the condition that one of the two arrays is empty
        if (nums1.size() == 0) {
            if (isOdd) ans = nums2[k];
            else ans = 1.0*(nums2[k -1] + nums2[k]) /2;
            return ans;
        } else if (nums2.size() == 0) {
            if (isOdd) ans = nums1[k];
            else ans = 1.0*(nums1[k -1] + nums1[k]) /2;
            return ans;
        }
        test(0, nums1.size() -1, 0, nums2.size() -1);
        return ans;
    }
    void test(int al, int ar, int bl, int br) {
        // the end conditon
        // the whole array a smaller than array b
        int m = ar - al +1, n = br - bl+1;
        if (a[ar] <= b[bl]) {
            if (isOdd) ans = m > k ? a[al + k] : b[bl + k - m];
            else {
                int left = m > k -1 ? a[al + k -1] : b[bl + k -1 - m];
                int right = m > k ? a[al + k] : b[bl + k - m];
                ans = 1.0*(left + right) / 2;
            }
            return;
        }
        // the whole array b smaller than array a
        else if (b[br] <= a[al]) {
            if (isOdd) ans = n > k ? b[bl + k] : a[al + k - n];
            else {
                int left = n > k -1 ? b[bl + k -1] : a[al + k -1- n];
                int right = n > k ? b[bl + k] : a[al + k -n];
                ans = 1.0*(left + right) / 2;
            }
            return;
        }
        if (ar - al == 0) {
            vector<int> temp(b.begin() + bl, b.begin() + br +1);
            temp.push_back(a[al]);
            sort(temp.begin(), temp.end());
            if (isOdd) ans = temp[k];
            else ans = 1.0*(temp[k -1] + temp[k]) /2;
            return;
        } else if (br - bl == 0) {
            vector<int> temp(a.begin() + al, a.begin() + ar +1);
            temp.push_back(b[bl]);
            sort(temp.begin(), temp.end());
            if (isOdd) ans = temp[k];
            else ans = 1.0*(temp[k -1] + temp[k]) /2;
            return;
        }
        // m is the length of array a, n is the length of array b
        if (a[al + m/2] >= b[bl + n/2]) {;
            if (k <= ((m + n)>>1)) {
                ar = ar - m/2;
            } else {
                bl = bl + n/2;
                k = k - n/2;
            }
        } else {
            if (k <= ((m + n) >> 1)) {
                br = br - n/2;
            } else {
                al = al + m/2;
                k = k - m/2;
            }
        }
        test(al,ar,bl,br);
    }

    void testWithoutRecursion(int al, int ar, int bl, int br) {
        while (1) {
            // the end conditon
            // the whole array a smaller than array b
            int m = ar - al +1, n = br - bl+1;
            if (a[ar] <= b[bl]) {
                if (isOdd) ans = m > k ? a[al + k] : b[bl + k - m];
                else {
                    int left = m > k -1 ? a[al + k -1] : b[bl + k -1 - m];
                    int right = m > k ? a[al + k] : b[bl + k - m];
                    ans = 1.0*(left + right) / 2;
                }
                return;
            }
            // the whole array b smaller than array a
            else if (b[br] <= a[al]) {
                if (isOdd) ans = n > k ? b[bl + k] : a[al + k - n];
                else {
                    int left = n > k -1 ? b[bl + k -1] : a[al + k -1- n];
                    int right = n > k ? b[bl + k] : a[al + k -n];
                    ans = 1.0*(left + right) / 2;
                }
                return;
            }
            if (ar - al == 0) {
                vector<int> temp(b.begin() + bl, b.begin() + br +1);
                temp.push_back(a[al]);
                sort(temp.begin(), temp.end());
                if (isOdd) ans = temp[k];
                else ans = 1.0*(temp[k -1] + temp[k]) /2;
                return;
            } else if (br - bl == 0) {
                vector<int> temp(a.begin() + al, a.begin() + ar +1);
                temp.push_back(b[bl]);
                sort(temp.begin(), temp.end());
                if (isOdd) ans = temp[k];
                else ans = 1.0*(temp[k -1] + temp[k]) /2;
                return;
            }
            // m is the length of array a, n is the length of array b
            if (a[al + m/2] >= b[bl + n/2]) {
                if (k <= ((m + n)>>1)) {
                    ar = ar - m/2;
                } else {
                    bl = bl + n/2;
                    k = k - n/2;
                }
            } else {
                if (k <= ((m + n) >> 1)) {
                    br = br - n/2;
                } else {
                    al = al + m/2;
                    k = k - m/2;
                }
            }
        }
    }

    vector<int> a;
    vector<int> b;
    int k;
    double ans;
    // 0 means even, 1 means odd
    bool isOdd;
};

猜你喜欢

转载自blog.csdn.net/qq_34035179/article/details/78220013
今日推荐