手写堆排序 python

import numpy as np


class HeapSort(object):
    """Array-backed max-heap with an in-place heap sort.

    Elements live in ``arr[1..count]``; index 0 is an unused placeholder so
    that the children of the node at index ``i`` sit at ``2*i`` and
    ``2*i + 1``.  ``arr`` is padded with zeros up to the current capacity,
    which means a heap of capacity ``c`` can hold at most ``c - 1`` elements.
    """

    # Length of the backing list of a freshly created heap.
    _INITIAL_CAPACITY = 50
    # Number of slots added by each call to extend_capacity().
    _GROWTH_STEP = 30

    def __init__(self):
        # State is kept per instance: a class-level list would be shared by
        # every HeapSort object and they would overwrite each other's data.
        self.__count = 0
        self.__capacity = self._INITIAL_CAPACITY
        self.arr = [0] * self.__capacity

    def get_capacity(self):
        """Return the length of the backing list (usable slots + 1)."""
        return self.__capacity

    def extend_capacity(self):
        """Grow the backing list by a fixed step of zero-filled slots."""
        self.__capacity += self._GROWTH_STEP
        self.arr.extend([0] * self._GROWTH_STEP)

    def get_count(self):
        """Return the number of elements currently in the heap."""
        return self.__count

    def init_heap(self, nums):
        """Replace the heap's contents with ``nums`` and heapify them.

        Any elements previously stored are discarded.  The backing list is
        grown as often as needed to fit all items.

        Args:
            nums: iterable of mutually comparable values.
        """
        items = list(nums)
        # Items occupy indices 1..len(items), so the highest index used is
        # len(items); the capacity must be strictly greater than that.
        while self.get_capacity() <= len(items):
            self.extend_capacity()
            print("扩容+30 成功")
        print(self.arr)
        # Rewrite the storage in place: placeholder, items, zero padding.
        padding = self.get_capacity() - len(items) - 1
        self.arr[:] = [0] + items + [0] * padding
        self.__count = len(items)
        print("the count is {}".format(self.__count))

        print("origin is :")
        print(self.arr)
        self.adjust()
        print("init_heap result is :")
        print(self.arr)

    def shift_down(self, pos):
        """Sift the element at index ``pos`` down until neither child is larger.

        At each step the larger of the existing children is chosen; if it is
        greater than the parent the two are swapped and the walk continues
        from the child's position, otherwise the walk stops.
        """
        count = self.get_count()
        child = 2 * pos
        while child <= count:
            # Prefer the right child when it exists and is the larger one.
            if child + 1 <= count and self.arr[child + 1] > self.arr[child]:
                child += 1
            if self.arr[pos] < self.arr[child]:
                self.arr[pos], self.arr[child] = self.arr[child], self.arr[pos]
                pos = child
                child = 2 * pos
            else:
                break

    def _shift_up(self, pos):
        """Sift the element at index ``pos`` up while it exceeds its parent."""
        while pos > 1 and self.arr[pos // 2] < self.arr[pos]:
            parent = pos // 2
            self.arr[pos], self.arr[parent] = self.arr[parent], self.arr[pos]
            pos = parent

    def adjust(self):
        """Heapify ``arr[1..count]`` by sifting down every internal node.

        Nodes are visited from the last parent (``count // 2``) back to the
        root, so each subtree is already a heap when its root is processed.
        """
        for node in range(self.get_count() // 2, 0, -1):
            self.shift_down(node)

    def is_big_heap(self):
        """Return True if every element is <= its parent (max-heap property)."""
        return all(
            self.arr[child // 2] >= self.arr[child]
            for child in range(2, self.get_count() + 1)
        )

    def get_max(self):
        """Remove and return the largest element, then restore the heap.

        The removed element is swapped into the slot just past the new end
        of the heap, which is what lets heap_sort() sort ``arr`` in place.

        Raises:
            IndexError: if the heap is empty.
        """
        count = self.get_count()
        if count == 0:
            raise IndexError("get_max from an empty heap")
        max_num = self.arr[1]
        self.arr[1], self.arr[count] = self.arr[count], self.arr[1]
        self.__count = count - 1
        # Only the root can violate the heap property after the swap.
        self.shift_down(1)
        return max_num

    def heap_sort(self):
        """Drain the heap, print the elements in ascending order and return them.

        Afterwards the heap is empty and the slots the elements occupied in
        ``arr`` hold them in ascending order.

        Returns:
            list: the heap's elements sorted ascending.
        """
        descending = []
        while self.get_count() > 0:
            descending.append(self.get_max())
        result = descending[::-1]
        print("heap_sort result is :")
        print(result)
        return result

    def insert(self, num):
        """Add ``num`` to the heap, growing the storage when it is full."""
        new_pos = self.get_count() + 1
        # Index new_pos must exist, so the capacity has to exceed it.
        if self.__capacity <= new_pos:
            self.extend_capacity()
        self.__count = new_pos
        self.arr[new_pos] = num
        self._shift_up(new_pos)
        print("Afet insert :\n{}".format(self.arr))

    def print_heap(self):
        """Print the whole backing list, including unused slots."""
        print("heap status is :")
        print(self.arr)

if __name__ == '__main__':
    # Demo run: heapify 20 random integers from [0, 100), insert one more
    # value, check the heap property, then sort and dump the storage.
    samples = np.random.randint(0, 100, 20)
    print(samples)

    demo_heap = HeapSort()
    demo_heap.init_heap(samples)
    demo_heap.insert(55)

    heap_ok = demo_heap.is_big_heap()
    print("Is a big_heap ? : {}".format(heap_ok))

    demo_heap.heap_sort()
    demo_heap.print_heap()

猜你喜欢

转载自blog.csdn.net/weixin_36149892/article/details/80412828