题目描述
- 输入n个整数,找出其中最小的k个数。例如输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。
算法分析
- 法一:数组排序后,找到前k个数,快速排序时间复杂度O(nlogn)。
- 法二:最大堆法,采用基于红黑树实现的最大堆,红黑树插入删除的时间复杂度为logk,总体时间复杂度为O(nlogk)。
提交代码:
class Solution {
public:
/* 先排序,再求解 */
vector<int> GetLeastNumbers_Solution(vector<int> input, int k) {
if (input.size() < k || input.empty() || k <= 0)
return vector<int>();
int length = input.size() - 1;
QuickSort(input, 0, length - 1);
vector<int> result;
for (int i = 0; i < k; ++i)
result.push_back(input[i]);
return result;
}
void QuickSort(vector<int> &input, int beg, int end)
{
if (beg < end)
{
int index = Partition(input, beg, end);
QuickSort(input, beg, index - 1);
QuickSort(input, index + 1, end);
}
}
int Partition(vector<int> &input, int beg, int end)
{
/* 简单取第一个数作为中枢值 */
int piviotkey = input[beg];
while (beg < end)
{
while (beg < end && input[end] >= piviotkey)
--end;
swap(input[beg], input[end]);
while (beg < end && input[beg] <= piviotkey)
++beg;
swap(input[beg], input[end]);
}
return beg;
}
/* 采用红黑树实现的最大堆,红黑树插入删除的时间复杂度为logk */
vector<int> GetLeastNumbers_Solution2(vector<int> input, int k) {
/* set及multiset均为红黑树实现 greater<int>在functional头文件中*/
if (input.size() < k || input.empty() || k <= 0)
return vector<int>();
multiset<int, greater<int>> maxStack;
int i = 0;
for (; i < k; ++i)
maxStack.insert(input[i]);
for (; i < input.size(); ++i)
{
if (input[i] < *maxStack.begin())
{
maxStack.erase(maxStack.begin());
maxStack.insert(input[i]);
}
}
vector<int> result;
for (auto iter = maxStack.rbegin(); iter != maxStack.rend(); ++iter)
result.push_back(*iter);
return result;
}
};
测试代码:
// ====================测试代码====================
void Test(char* testName, vector<int> data, vector<int> expectedResult, int k)
{
if (testName != nullptr)
printf("%s begins: \n", testName);
if (expectedResult.empty())
printf("The input is invalid, we don't expect any result.\n");
else
{
printf("Expected result: \n");
for (int i = 0; i < k; ++i)
printf("%d\t", expectedResult[i]);
printf("\n");
}
printf("Result for solution:\n");
Solution s;
vector<int> output = s.GetLeastNumbers_Solution(data, k);
if (!expectedResult.empty())
{
for (int i = 0; i < k; ++i)
printf("%d\t", output[i]);
printf("\n");
}
}
// k小于数组的长度
void Test1()
{
vector<int> data = { 4, 5, 1, 6, 2, 7, 3, 8 };
vector<int> expected = { 1, 2, 3, 4 };
Test("Test1", data, expected, expected.size());
}
// k等于数组的长度
void Test2()
{
vector<int> data = { 4, 5, 1, 6, 2, 7, 3, 8 };
vector<int> expected = { 1, 2, 3, 4, 5, 6, 7, 8 };
Test("Test2", data, expected, expected.size());
}
// k大于数组的长度
void Test3()
{
vector<int> data = { 4, 5, 1, 6, 2, 7, 3, 8 };
vector<int> expected;
Test("Test3", data, expected, expected.size());
}
// k等于1
void Test4()
{
vector<int> data = { 4, 5, 1, 6, 2, 7, 3, 8 };
vector<int> expected = { 1 };
Test("Test4", data, expected, expected.size());
}
// k等于0
void Test5()
{
vector<int> data = { 4, 5, 1, 6, 2, 7, 3, 8 };
vector<int> expected;
Test("Test5", data, expected, expected.size());
}
// 数组中有相同的数字
void Test6()
{
vector<int> data = { 4, 5, 1, 6, 2, 7, 2, 8 };
vector<int> expected = { 1, 2 };
Test("Test6", data, expected, expected.size());
}
// 输入空指针
void Test7()
{
vector<int> expected;
Test("Test7", vector<int>(), expected, expected.size());
}
int main(int argc, char* argv[])
{
Test1();
Test2();
Test3();
Test4();
Test5();
Test6();
Test7();
return 0;
}