java8 PriorityQueue接口实现源码解析

一、类继承关系

二、使用说明

       PriorityQueue是支持排序的FIFO队列,内部实现是基于数组的二叉堆,数组的第一个元素为队列中最小的元素,队列的head,每次执行poll(),remove(),peek(),element()方法时都是操作队列的head元素。通过循环调用poll()方法直到返回null可保证按照排序好的顺序遍历,但是通过iterator()返回的迭代器遍历时遍历元素的顺序不是任何特定的顺序。PriorityQueue跟SortedMap一致,要求插入的元素不能为null,必须实现Comparable接口或者可以被构造时传入的Compartor实例比较,否则报错ClassCastException。PriorityQueue无容量限制,内部会根据需要自动扩容;非同步,如果要求线程安全,推荐使用PriorityBlockingQueue。测试用例如下:

public class User {
    private String userName;

    private Integer age;

    public User() {
    }

    public User(String userName, Integer age) {
        this.userName = userName;
        this.age = age;
    }

    public String getUserName() {
        return userName;
    }

    public void setUserName(String userName) {
        this.userName = userName;
    }

    public Integer getAge() {
        return age;
    }

    public void setAge(Integer age) {
        this.age = age;
    }

    @Override
    public String toString() {
        return "User{" +
                "userName='" + userName + '\'' +
                ", age=" + age +
                '}';
    }
}


   @Test
    public void test() throws Exception {
        Queue<String> test=new PriorityQueue<>();
        test.add("test2");
        test.add("test3");
        test.add("test5");
        test.add("test4");
        test.add("test");
        test.add("test6");
        int size=test.size();
        System.out.println(size);
        System.out.println("=========iterator========");
        for(String s:test){
            System.out.println(s);//不是按照排序或者插入的顺序来的,而是数组实际保存元素的顺序
        }
        System.out.println("=========poll========");
        for(int i=0;i<size;i++){
            System.out.println(i+":"+test.poll());//按照升序的顺序遍历的,每次获取的都是当前队列的最小值
        }
    }

    @Test
    public void test2() throws Exception {
        //通过传递特定Comparator实现倒序排序
        Queue<User> queue=new PriorityQueue<>(new Comparator<User>() {
            @Override
            public int compare(User o1, User o2) {
                return o2.getAge()-o1.getAge();
            }
        });
        queue.add(new User("shl",12));
        queue.add(new User("shl2",14));
        queue.add(new User("shl3",10));
        queue.add(new User("shl4",15));
        queue.add(new User("shl5",9));
        queue.add(new User("shl6",16));
        int size=queue.size();
        for(int i=0;i<size;i++){
            System.out.println(queue.poll());
        }

    }

三、二叉堆说明

      二叉堆是堆的一种,使用完全二叉树来实现。所谓完全二叉树,即高度为n的二叉树,其前n-1层必须被填满,第n层也要从左到右顺序填满。在二叉堆中,所有父节点的值均不大于(或不小于)其左右子节点的值,但对左右节点值的大小关系不做限制。若所有父节点值均不大于其左右子结点的值,这样的二叉堆叫做小根堆,小根堆根结点的值是该堆中所有结点的最小值;同样的,当所有父节点的值都不小于其左右孩子的值时,这样的对叫做大根堆,大根堆根结点的值为该堆所有结点的最大值,可以利用此特点快速找出最大值或者最小值,进而实现排序。因为二叉堆是一棵完全二叉树,不需要红黑树那样通过父节点保存对左右子节点引用的方式构建红黑树,所以通常采用数组来实现二叉堆,下图是一个已经平衡的二叉堆及其对应的存储数组。

 

观察可发现,基于数组实现的二叉堆,对于数组中任意位置的n上元素,其左孩子在[2n+1]位置上,右孩子[2(n+1)]位置,它的父亲则在[(n-1)>>1]上,而根的位置则是[0]。

四、接口实现

1、全局属性

    //默认的初始化容量
    private static final int DEFAULT_INITIAL_CAPACITY = 11;

    //实际保存元素的数组
    transient Object[] queue; 
    
    //记录元素个数
    private int size = 0;
    
    //排序用的比较器
    private final Comparator<? super E> comparator;

    //记录修改次数
    transient int modCount = 0; 
    
    //数组的最大大小
    private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;

2、构造方法

//comparator为null时,要求插入的元素必须实现Comparable接口
    public PriorityQueue() {
        this(DEFAULT_INITIAL_CAPACITY, null);
    }


    public PriorityQueue(int initialCapacity) {
        this(initialCapacity, null);
    }


    public PriorityQueue(Comparator<? super E> comparator) {
        this(DEFAULT_INITIAL_CAPACITY, comparator);
    }

    public PriorityQueue(int initialCapacity,
                         Comparator<? super E> comparator) {
        if (initialCapacity < 1)
            throw new IllegalArgumentException();
        this.queue = new Object[initialCapacity];
        this.comparator = comparator;
    }

    @SuppressWarnings("unchecked")
    public PriorityQueue(Collection<? extends E> c) {
        if (c instanceof SortedSet<?>) {
            SortedSet<? extends E> ss = (SortedSet<? extends E>) c;
            this.comparator = (Comparator<? super E>) ss.comparator();
            initElementsFromCollection(ss);
        }
        else if (c instanceof PriorityQueue<?>) {
            PriorityQueue<? extends E> pq = (PriorityQueue<? extends E>) c;
            this.comparator = (Comparator<? super E>) pq.comparator();
            initFromPriorityQueue(pq);
        }
        else {
            this.comparator = null;
            initFromCollection(c);
        }
    }


    @SuppressWarnings("unchecked")
    public PriorityQueue(PriorityQueue<? extends E> c) {
        this.comparator = (Comparator<? super E>) c.comparator();
        initFromPriorityQueue(c);
    }


    @SuppressWarnings("unchecked")
    public PriorityQueue(SortedSet<? extends E> c) {
        this.comparator = (Comparator<? super E>) c.comparator();
        initElementsFromCollection(c);
    }

    private void initFromPriorityQueue(PriorityQueue<? extends E> c) {
        if (c.getClass() == PriorityQueue.class) {
            //复制关键属性即可
            this.queue = c.toArray();
            this.size = c.size();
        } else {
            initFromCollection(c);
        }
    }


    private void initFromCollection(Collection<? extends E> c) {
        initElementsFromCollection(c);
        heapify();
    }

    //复制数组,校验元素是否为空
    private void initElementsFromCollection(Collection<? extends E> c) {
        Object[] a = c.toArray();
        // If c.toArray incorrectly doesn't return Object[], copy it.
        if (a.getClass() != Object[].class)
            a = Arrays.copyOf(a, a.length, Object[].class);
        int len = a.length;
        if (len == 1 || this.comparator != null)
            for (int i = 0; i < len; i++)
                //校验元素是否为空
                if (a[i] == null)
                    throw new NullPointerException();
        this.queue = a;
        this.size = a.length;
    }

    @SuppressWarnings("unchecked")
    private void heapify() {
        //将数组元素按照二叉堆的存储方式调整存储位置
        //对前一半的元素执行siftDown,因为执行过程中如果后一半的元素小于前面的则会与该元素交换位置,所以不需要对所有的元素都执行siftDown
        //这里也可对后一半的元素执行siftUp来构建二叉堆
        for (int i = (size >>> 1) - 1; i >= 0; i--)
            siftDown(i, (E) queue[i]);
    }

3、添加元素和扩容

   private void grow(int minCapacity) {
        int oldCapacity = queue.length;
        //原容量小于64则扩容一倍,否则扩容50%
        int newCapacity = oldCapacity + ((oldCapacity < 64) ?
                                         (oldCapacity + 2) :
                                         (oldCapacity >> 1));
        if (newCapacity - MAX_ARRAY_SIZE > 0)
            newCapacity = hugeCapacity(minCapacity);
        //将原数组复制到新数组中
        queue = Arrays.copyOf(queue, newCapacity);
    }

    private static int hugeCapacity(int minCapacity) {
        if (minCapacity < 0) // overflow
            throw new OutOfMemoryError();
        return (minCapacity > MAX_ARRAY_SIZE) ?
            Integer.MAX_VALUE :
            MAX_ARRAY_SIZE;
    }


    public boolean add(E e) {
        return offer(e);
    }


    public boolean offer(E e) {
        if (e == null)
            throw new NullPointerException();
        modCount++;
        //i表示插入数组的下标
        int i = size;
        //超过当前数组容量执行扩容
        if (i >= queue.length)
            grow(i + 1);
        size = i + 1;
        if (i == 0) //数组为空
            queue[0] = e;
        else
            //将元素e插入到数组下标为i的位置,为了二叉堆平衡可能调整实际插入的位置
            //因为二叉堆插入新的节点时总是从最下面一层开始插入,所以如果发生调整只能往上调整
            siftUp(i, e);
        return true;
    }

    /**
     * 插入元素x到指定的数组下标k上,为了保持二叉堆的平衡,如果k小于父节点则需要不断将x往上移动
     * 直到x大于或者等于他的父节点
     */
    private void siftUp(int k, E x) {
        //siftUpUsingComparator和siftUpComparable逻辑相同就是比较大小的方式不同
        if (comparator != null)
            siftUpUsingComparator(k, x);
        else
            siftUpComparable(k, x);
    }

    @SuppressWarnings("unchecked")
    private void siftUpComparable(int k, E x) {
        //如果x未实现Comparable接口,此处会抛出ClassCastException异常
        Comparable<? super E> key = (Comparable<? super E>) x;
        while (k > 0) {
            //计算父节点所在的数组下标
            int parent = (k - 1) >>> 1;
            Object e = queue[parent];
            //如果大于或者等于父节点则返回
            if (key.compareTo((E) e) >= 0)
                break;
            //如果小于父节点将父节点与x交换位置,继续往上遍历比较父节点,直到根节点为止
            queue[k] = e;
            k = parent;
        }
        queue[k] = key;
    }

    @SuppressWarnings("unchecked")
    private void siftUpUsingComparator(int k, E x) {
        while (k > 0) {
            int parent = (k - 1) >>> 1;
            Object e = queue[parent];
            if (comparator.compare(x, (E) e) >= 0)
                break;
            queue[k] = e;
            k = parent;
        }
        queue[k] = x;
    }

 4、删除元素

 private int indexOf(Object o) {
        if (o != null) {
            //遍历数组找到目标元素o在数组中的索引位置
            for (int i = 0; i < size; i++)
                if (o.equals(queue[i]))
                    return i;
        }
        return -1;
    }


    public boolean remove(Object o) {
        int i = indexOf(o);
        //数组中不存在该元素
        if (i == -1)
            return false;
        else {
            //移除指定下标i的元素
            removeAt(i);
            return true;
        }
    }


    @SuppressWarnings("unchecked")
    public E poll() {
        if (size == 0)
            return null;
        int s = --size;
        modCount++;
        //每次移除都是queue[0],然后将数组最后一个元素放到下标为0的位置
        E result = (E) queue[0];
        E x = (E) queue[s];
        queue[s] = null;
        if (s != 0)
            //数组最后一个元素肯定不是最小的,放到queue[0]即二叉堆顶部的时候必须往下调整位置
            siftDown(0, x);
        return result;
    }


    @SuppressWarnings("unchecked")
    private E removeAt(int i) {
        modCount++;
        int s = --size;
        if (s == i) //如果移除的是数组最后一个元素,直接置空,不影响二叉堆的平衡
            queue[i] = null;
        else {
            E moved = (E) queue[s];
            queue[s] = null;
            //moved元素是数组中最后一个元素,肯定位于二叉堆最下面的一层,因为下标i对应的节点跟moved节点不一定在同一个子树中,
            //而二叉堆不同子树无任何大小关系, 所以既可能大于也可能小于下标i节点对应的子节点
            //执行siftDown确保moved元素肯定小于其最小子节点元素
            siftDown(i, moved);
            //即插入到下标为i时,moved元素刚好小于其最小的子节点,此时moved元素可能小于其父节点,所以需要执行siftUp方法,
            // 确保moved元素肯定大于或者等于其父节点
            if (queue[i] == moved) {
                siftUp(i, moved);
                //此时的moved元素实际没有向下调整而是向上调整了,即转移到了下标为i的前面,按数组下标遍历时就无法遍历了
                //为了迭代器能够遍历该元素,将该元素返回,由迭代器另外处理
                if (queue[i] != moved)
                    return moved;
            }
        }
        return null;
    }



    /**
     * 与siftUp相反,不断地将元素往下移动直到目标元素x小于或者等于最小的子节点
     * 之所以是跟最小的子节点比较,是为了确保如果发生跟最小的子节点交换位置,最小的子节点变成父节点了,这时父节点必须都小于两个
     * 子节点
     */
    private void siftDown(int k, E x) {
        //siftDownUsingComparator和siftDownComparable逻辑相同,
        if (comparator != null)
            siftDownUsingComparator(k, x);
        else
            siftDownComparable(k, x);
    }

    @SuppressWarnings("unchecked")
    private void siftDownComparable(int k, E x) {
        Comparable<? super E> key = (Comparable<? super E>)x;
        int half = size >>> 1;
        while (k < half) { //假如下标k的元素存在左节点,则2k+1<=size,所以有k < size>>>1
            int child = (k << 1) + 1; //计算左节点下标
            Object c = queue[child];
            int right = child + 1;//计算右节点下标
            //如果左节点的值小于右节点则将c置为右节点的值,即c表示子节点中最小的值
            if (right < size && //表示存在右节点
                ((Comparable<? super E>) c).compareTo((E) queue[right]) > 0)
                c = queue[child = right];
            //如果父节点小于或者等于最小的子节点,则终止循环
            if (key.compareTo((E) c) <= 0)
                break;
            //交换key和最小子节点的位置,继续往下比较子节点
            queue[k] = c;
            k = child;
        }
        queue[k] = key;
    }

    @SuppressWarnings("unchecked")
    private void siftDownUsingComparator(int k, E x) {
        int half = size >>> 1;
        while (k < half) {
            int child = (k << 1) + 1;
            Object c = queue[child];
            int right = child + 1;
            if (right < size &&
                comparator.compare((E) c, (E) queue[right]) > 0)
                c = queue[child = right];
            if (comparator.compare(x, (E) c) <= 0)
                break;
            queue[k] = c;
            k = child;
        }
        queue[k] = x;
    }

5、元素遍历

    /**
     * 用==比较元素是否相同,迭代器使用,相比equals()方法更快
     */
    boolean removeEq(Object o) {
        for (int i = 0; i < size; i++) {
            if (o == queue[i]) {
                removeAt(i);
                return true;
            }
        }
        return false;
    }

    public Iterator<E> iterator() {
        return new Itr();
    }

    private final class Itr implements Iterator<E> {
        /**
         * 当前遍历元素的数组下标
         */
        private int cursor = 0;

        /**
         * 上一次返回元素的数组下标
         */
        private int lastRet = -1;

        /**
         * 存储那些因为siftup从未遍历的位置交互到已遍历的位置上的元素,确保这些元素也能被遍历到
         */
        private ArrayDeque<E> forgetMeNot = null;

        /**
         * 上一次返回的元素,遍历forgetMeNot中的元素时使用
         */
        private E lastRetElt = null;

        /**
         * 期望的修改次数
         */
        private int expectedModCount = modCount;

        public boolean hasNext() {
            return cursor < size ||
                (forgetMeNot != null && !forgetMeNot.isEmpty());
        }

        @SuppressWarnings("unchecked")
        public E next() {
            if (expectedModCount != modCount)
                throw new ConcurrentModificationException();
            //按照数组下标遍历
            if (cursor < size)
                return (E) queue[lastRet = cursor++];
            //数组遍历完成,遍历forgetMeNot中保存的元素
            if (forgetMeNot != null) {
                lastRet = -1;//表示数组遍历完成
                lastRetElt = forgetMeNot.poll();
                if (lastRetElt != null)
                    return lastRetElt;
            }
            throw new NoSuchElementException();
        }

        public void remove() {
            if (expectedModCount != modCount)
                throw new ConcurrentModificationException();
            if (lastRet != -1) {//遍历数组中
                E moved = PriorityQueue.this.removeAt(lastRet);
                lastRet = -1;
                //当removeAt方法返回null时,实际执行siftDown,lastRet处的元素会用lastRet后面尚未遍历到的元素代替,
                // 下次遍历时还是从lastRet开始,所以cursor需要减一
                if (moved == null)
                    cursor--;
                else {
                    //当removeAt方法返回非null元素时,实际执行siftUp,lastRet处的元素会用lastRet前面的元素代替,
                    // 该元素已经遍历过了,下一次遍历从下一个元素即cursor处开始
                    //返回的非null元素被交换到了lastRet前面还未遍历,所以需要单独保存起来,等数组遍历完成后再遍历
                    if (forgetMeNot == null)
                        forgetMeNot = new ArrayDeque<>();
                    forgetMeNot.add(moved);
                }
            } else if (lastRetElt != null) {//遍历forgetMeNot中的元素
                PriorityQueue.this.removeEq(lastRetElt);
                lastRetElt = null;
            } else {
                throw new IllegalStateException();
            }
            expectedModCount = modCount;
        }
    }

猜你喜欢

转载自blog.csdn.net/qq_31865983/article/details/87543139
今日推荐