Phaser 源码分析

Phaser

一个可重用的同步屏障,功能上与 CyclicBarrier 和 CountDownLatch 类似,但是支持更灵活的使用。
把多个线程协作执行的任务划分为多个阶段,编程时需要明确各个阶段的任务,每个阶段都有指定个参与者,
线程都可以随时注册并参与到某个阶段。

创建实例

    /**
     * Primary state representation, holding four bit-fields:
     *
     * unarrived:还未到达的参与者数目 (bits  0-15)
     * parties:当前阶段总的参与者数目  (bits 16-31)
     * phase:屏障所处的阶段                (bits 32-62)
     * terminated:屏障是否终止             (bit  63 / sign)
     */
    private volatile long state;

    // 最大的参与者数目
    private static final int  MAX_PARTIES     = 0xffff;
    // 最大的阶段值
    private static final int  MAX_PHASE       = Integer.MAX_VALUE;
    // 参与者移位
    private static final int  PARTIES_SHIFT   = 16;
    // 阶段移位
    private static final int  PHASE_SHIFT     = 32;
    // 未到达参与者数掩码
    private static final int  UNARRIVED_MASK  = 0xffff;      // to mask ints
    // 总参与者数掩码
    private static final long PARTIES_MASK    = 0xffff0000L; // to mask longs
    private static final long COUNTS_MASK     = 0xffffffffL;
    // 终止位
    private static final long TERMINATION_BIT = 1L << 63;

    // 一个到达者
    private static final int  ONE_ARRIVAL     = 1;
    // 一个参与者
    private static final int  ONE_PARTY       = 1 << PARTIES_SHIFT;
    // 撤销一个参与者
    private static final int  ONE_DEREGISTER  = ONE_ARRIVAL|ONE_PARTY;
    // 0 个参与者,1 个达到者
    private static final int  EMPTY           = 1;

    /**
     *  此 Phaser 的 parent
     */
    private final Phaser parent;

    /**
     * The root of phaser tree. Equals this if not in a tree.
     */
    private final Phaser root;

    // 阶段值为偶数时, Treiber 栈节点,阻塞线程驻留在节点上
    private final AtomicReference<QNode> evenQ;
    // 阶段值为奇数时, Treiber 栈节点,阻塞线程驻留在节点上
    private final AtomicReference<QNode> oddQ;

    /**
     *  创建一个无父 Phaser、参与者数为 0 的 Phaser
     */
    public Phaser() {
        this(null, 0);
    }

    /**
     *  创建一个无父 Phaser、参与者数为 parties 的 Phaser
     */
    public Phaser(int parties) {
        this(null, parties);
    }

    /**
     * 创建一个父 Phaser 为 parent、参与者数为 0 的 Phaser
     */
    public Phaser(Phaser parent) {
        this(parent, 0);
    }

    /**
     * 创建一个父 Phaser 为 parent、参与者数为 parties 的 Phaser
     */
    public Phaser(Phaser parent, int parties) {
        if (parties >>> PARTIES_SHIFT != 0) {
            throw new IllegalArgumentException("Illegal number of parties");
        }
        // 初始阶段值为 0
        int phase = 0;
        // 写入父  Phaser
        this.parent = parent;
        // 1)如果存在父 Phaser
        if (parent != null) {
            final Phaser root = parent.root;
            this.root = root;
            evenQ = root.evenQ;
            oddQ = root.oddQ;
            if (parties != 0) {
                phase = parent.doRegister(1);
            }
        }
        // 2)不存在父 Phaser
        else {
            root = this;
            evenQ = new AtomicReference<>();
            oddQ = new AtomicReference<>();
        }
        // 计算状态值
        state = parties == 0 ? (long)EMPTY :
            (long)phase << PHASE_SHIFT |
            (long)parties << PARTIES_SHIFT |
            parties;
    }

达到屏障

  • 到达当前阶段,此方法不会阻塞
    /**
     *  到达当前阶段,此方法不会阻塞
     */
    public int arrive() {
        return doArrive(ONE_ARRIVAL);
    }

    private int doArrive(int adjust) {
        final Phaser root = this.root;
        for (;;) {
            /**
             * 1)如果无 root Phaser,即 root==this,则返回其状态值
             * 2)如果有 root Phaser
             */
            long s = root == this ? state : reconcileState();
            // 读取阶段值
            int phase = (int) (s >>> PHASE_SHIFT);
            if (phase < 0) {
                return phase;
            }
            final int counts = (int) s;
            // 计算未到达屏障的参与者数目
            final int unarrived = counts == EMPTY ? 0 : counts & UNARRIVED_MASK;
            // 如果都已经到达屏障,则抛出 IllegalStateException
            if (unarrived <= 0) {
                throw new IllegalStateException(badArrive(s));
            }
            // 递减达到者数目
            if (STATE.compareAndSet(this, s, s -= adjust)) {
                // 1)如果当前线程是最后一个达到屏障的线程
                if (unarrived == 1) {
                    long n = s & PARTIES_MASK; // base of next state
                    // 计算下一阶段的参与者数目
                    final int nextUnarrived = (int) n >>> PARTIES_SHIFT;
                    // 如果无 root Phaser
                    if (root == this) {
                        // 1)触发 onAdvance:注册的参与者数目==0
                        if (onAdvance(phase, nextUnarrived)) {
                            n |= TERMINATION_BIT;
                            // 2)下一阶段的参与者数目为 0
                        } else if (nextUnarrived == 0) {
                            n |= EMPTY;
                        } else {
                            // 3)将参与者数目写入新的状态变量
                            n |= nextUnarrived;
                        }
                        // 计算下一阶段值
                        final int nextPhase = phase + 1 & MAX_PHASE;
                        // 写入阶段值
                        n |= (long) nextPhase << PHASE_SHIFT;
                        // 更新状态变量
                        STATE.compareAndSet(this, s, n);
                        // 释放在此阶段阻塞的所有线程
                        releaseWaiters(phase);
                        // 2)如果存在 root Phaser && 下一阶段的参与者数目为 0
                    } else if (nextUnarrived == 0) { // propagate deregistration
                        // 向上传播取消注册一个参与者
                        phase = parent.doArrive(ONE_DEREGISTER);
                        STATE.compareAndSet(this, s, s | EMPTY);
                    } else {
                        // 3)如果存在 root Phaser && 下一阶段的参与者数目 > 0
                        // 父 Phaser 到达
                        phase = parent.doArrive(ONE_ARRIVAL);
                    }
                }
                // 返回阶段值
                return phase;
            }
        }
    }
  • 到达此阶段,并阻塞等待其他参与者到达,等价于 {@code awaitAdvance(arrive())}.
    /**
     *  到达此阶段,并阻塞等待其他参与者到达,等价于 {@code awaitAdvance(arrive())}.
     */
    public int arriveAndAwaitAdvance() {
        // Specialization of doArrive+awaitAdvance eliminating some reads/paths
        final Phaser root = this.root;
        for (;;) {
            // 读取状态值
            long s = root == this ? state : reconcileState();
            // 读取阶段值
            final int phase = (int) (s >>> PHASE_SHIFT);
            if (phase < 0) {
                return phase;
            }
            final int counts = (int) s;
            // 计算未到达的参与者数目
            final int unarrived = counts == EMPTY ? 0 : counts & UNARRIVED_MASK;
            if (unarrived <= 0) {
                throw new IllegalStateException(badArrive(s));
            }
            // 未到达参与者数-1
            if (STATE.compareAndSet(this, s, s -= ONE_ARRIVAL)) {
                // 如果还有参与者未到达,则阻塞等待
                if (unarrived > 1) {
                    return root.internalAwaitAdvance(phase, null);
                }
                // 如果存在 root Phaser,阻塞等待所有父参与者到达其 Phaser
                if (root != this) {
                    return parent.arriveAndAwaitAdvance();
                }
                // 计算下一阶段的参与者数目
                long n = s & PARTIES_MASK; // base of next state
                // 计算下一阶段的未到达参与者数目
                final int nextUnarrived = (int) n >>> PARTIES_SHIFT;
                // 1)触发 onAdvance:注册的参与者数目==0
                if (onAdvance(phase, nextUnarrived)) {
                    n |= TERMINATION_BIT;
                    // 2)下一阶段的参与者数目为 0
                } else if (nextUnarrived == 0) {
                    n |= EMPTY;
                } else {
                    // 3)将参与者数目写入新的状态变量
                    n |= nextUnarrived;
                }
                // 计算下一阶段值
                final int nextPhase = phase + 1 & MAX_PHASE;
                n |= (long) nextPhase << PHASE_SHIFT;
                // 写入下一阶段值
                if (!STATE.compareAndSet(this, s, n)) {
                    return (int) (state >>> PHASE_SHIFT); // terminated
                }
                // 释放在此阶段阻塞等待的线程
                releaseWaiters(phase);
                return nextPhase;
            }
        }
    }

    private int internalAwaitAdvance(int phase, QNode node) {
        // 尝试释放在前一阶段阻塞的所有参与者
        releaseWaiters(phase - 1); // ensure old queue clean
        // 是否已经加入队列
        boolean queued = false; // true when node is enqueued
        int lastUnarrived = 0; // to increase spins upon change
        /**
         * 计算自旋值,单核 CPU 为 1,多核 CPU 为 256
         */
        int spins = SPINS_PER_ARRIVAL;
        long s;
        int p;
        // 目标阶段 phase 和当前阶段相等
        while ((p = (int) ((s = state) >>> PHASE_SHIFT)) == phase) {
            // 1)以不可中断的方式自旋
            if (node == null) { // spinning in noninterruptible mode
                final int unarrived = (int) s & UNARRIVED_MASK;
                // 未到达参与者数 != 上次的未到达数 && 未到达参与者数目 < CPU 核数
                if (unarrived != lastUnarrived && (lastUnarrived = unarrived) < NCPU) {
                    // 递增自旋次数
                    spins += SPINS_PER_ARRIVAL;
                }
                // 读取当前线程的中断状态
                final boolean interrupted = Thread.interrupted();
                // 1)线程已经中断 || 自旋完成
                if (interrupted || --spins < 0) { // need node to record intr
                    // 创建新的节点
                    node = new QNode(this, phase, false, false, 0L);
                    // 写入中断状态
                    node.wasInterrupted = interrupted;
                    // 2)执行自旋
                } else {
                    Thread.onSpinWait();
                }
            }
            // 2)节点不可释放
            else if (node.isReleasable()) {
                break;
                // 3)节点还未入队
            } else if (!queued) { // push onto queue
                // 根据阶段值计算目标队列
                final AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;
                // 将当前节点链接到 Treiber 栈顶
                final QNode q = node.next = head.get();
                /**
                 *  1)如果是第一个入队的节点 || 阶段值相等 && 
                 *  当前阶段就是目标阶段
                 */
                if ((q == null || q.phase == phase) && (int) (state >>> PHASE_SHIFT) == phase) {
                    // 更新头部节点
                    queued = head.compareAndSet(q, node);
                }
            }
            // 4)节点已经入队
            else {
                try {
                    // 阻塞当前线程
                    ForkJoinPool.managedBlock(node);
                } catch (final InterruptedException cantHappen) {
                    node.wasInterrupted = true;
                }
            }
        }

        if (node != null) {
            // 驻留线程不为 null,则将其清空
            if (node.thread != null) {
                node.thread = null; // avoid need for unpark()
            }
            // 节点是被中断的 && 线程可被中断
            if (node.wasInterrupted && !node.interruptible) {
                // 则中断当前线程
                Thread.currentThread().interrupt();
            }
            // 阶段一般已经更新了
            if (p == phase && (p = (int) (state >>> PHASE_SHIFT)) == phase) {
                return abortWait(phase); // possibly clean up on abort
            }
        }
        // 释放在此阶段阻塞等待的线程
        releaseWaiters(phase);
        return p;
    }

    static final class QNode implements ForkJoinPool.ManagedBlocker {
        // 所在 Phaser 引用
        final Phaser phaser;
        // 阶段值
        final int phase;
        // 线程是否可中断
        final boolean interruptible;
        // 线程是否可超时
        final boolean timed;
        // 线程是否被中断
        boolean wasInterrupted;
        // 超时时间
        long nanos;
        // 截止时间
        final long deadline;
        // 阻塞线程
        volatile Thread thread; // nulled to cancel wait
        // 下一个节点
        QNode next;
    }
  • 到达此阶段 && 将下一阶段的参与者数目-1
    /**
     *  到达此阶段 && 将下一阶段的参与者数目-1
     */
    public int arriveAndDeregister() {
        return doArrive(ONE_DEREGISTER);
    }

尝试在指定的阶段阻塞等待

    /**
     *  如果当前阶段值==phase,则阻塞等待
     */
    public int awaitAdvance(int phase) {
        final Phaser root = this.root;
        final long s = root == this ? state : reconcileState();
        // 读取阶段值
        final int p = (int) (s >>> PHASE_SHIFT);
        if (phase < 0) {
            return phase;
        }
        // 当前阶段和目标 phase 相等,则阻塞等待
        if (p == phase) {
            return root.internalAwaitAdvance(phase, null);
        }
        return p;
    }

    /**
     * 如果当前阶段值==phase,则阻塞等待,支持线程中断
     */
    public int awaitAdvanceInterruptibly(int phase) throws InterruptedException {
        final Phaser root = this.root;
        final long s = root == this ? state : reconcileState();
        int p = (int) (s >>> PHASE_SHIFT);
        if (phase < 0) {
            return phase;
        }
        if (p == phase) {
            // 创建一个支持线程中断的 QNode
            final QNode node = new QNode(this, phase, true, false, 0L);
            p = root.internalAwaitAdvance(phase, node);
            // 如果当前线程被中断
            if (node.wasInterrupted) {
                // 则抛出 InterruptedException 异常
                throw new InterruptedException();
            }
        }
        return p;
    }

    /**
     * 如果当前阶段值==phase,则阻塞等待直到超时,支持线程中断
     */
    public int awaitAdvanceInterruptibly(int phase, long timeout, TimeUnit unit)
            throws InterruptedException, TimeoutException {
        // 计算超时时间
        final long nanos = unit.toNanos(timeout);
        final Phaser root = this.root;
        final long s = root == this ? state : reconcileState();
        int p = (int) (s >>> PHASE_SHIFT);
        if (phase < 0) {
            return phase;
        }
        if (p == phase) {
            // 创建一个支持中断和超时的 QNode
            final QNode node = new QNode(this, phase, true, true, nanos);
            p = root.internalAwaitAdvance(phase, node);
            if (node.wasInterrupted) {
                throw new InterruptedException();
            } else if (p == phase) {
                throw new TimeoutException();
            }
        }
        return p;
    }

注册参与者

    /**
     *  注册一个参与者
     */
    public int register() {
        return doRegister(1);
    }

    private int doRegister(int registrations) {
        // 计算增量
        final long adjust = (long) registrations << PARTIES_SHIFT | registrations;
        final Phaser parent = this.parent;
        int phase;
        for (;;) {
            long s = parent == null ? state : reconcileState();
            final int counts = (int) s;
            // 计算总参与者数
            final int parties = counts >>> PARTIES_SHIFT;
            // 计算未到达参与者数
            final int unarrived = counts & UNARRIVED_MASK;
            // 注册参与者数超出上限
            if (registrations > MAX_PARTIES - parties) {
                throw new IllegalStateException(badRegister(s));
            }
            phase = (int) (s >>> PHASE_SHIFT);
            if (phase < 0) {
                break;
            }
            // 1)不是第一个注册者【已经注册了参与者】
            if (counts != EMPTY) { // not 1st registration
                if (parent == null || reconcileState() == s) {
                    // 1)所有的参与者都已经到达,则等待阶段更新
                    if (unarrived == 0) {
                        root.internalAwaitAdvance(phase, null);
                        // 2)递增参与者数目,出现竞争,则进行重试
                    } else if (STATE.compareAndSet(this, s, s + adjust)) {
                        break;
                    }
                }
                // 2)第一个根注册者
            } else if (parent == null) { // 1st root registration
                final long next = (long) phase << PHASE_SHIFT | adjust;
                if (STATE.compareAndSet(this, s, next)) {
                    break;
                }
            } else {
                synchronized (this) { // 1st sub registration
                    // 检查是否出现竞争
                    if (state == s) { // recheck under lock
                        // 注册一个参与者
                        phase = parent.doRegister(1);
                        if (phase < 0) {
                            break;
                        }
                        // 父 Phaser 注册成功后执行注册
                        while (!STATE.weakCompareAndSet(this, s, (long) phase << PHASE_SHIFT | adjust)) {
                            s = state;
                            phase = (int) (root.state >>> PHASE_SHIFT);
                            // assert (int)s == EMPTY;
                        }
                        break;
                    }
                }
            }
        }
        return phase;
    }
    
    // 同时注册 parties 个参与者
    public int bulkRegister(int parties) {
        if (parties < 0) {
            throw new IllegalArgumentException();
        }
        if (parties == 0) {
            return getPhase();
        }
        return doRegister(parties);
    }

猜你喜欢

转载自www.cnblogs.com/zhuxudong/p/10123826.html