聊聊并发:(十三)concurrent包并发辅助类之CountDownLatch源码分析

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/wtopps/article/details/84952842

前言

聊聊并发:(十二)concurrent包并发辅助类之CyclicBarrier源码分析

上篇文章,我们分析了并发辅助类CyclicBarrier的源码实现,本篇,我们继续聊聊与它功能非常相似的一个类,CountDownLatch类的使用方法以及实现机制。

CountDownLatch介绍

CountDownLatch是一个同步辅助类,与CyclicBarrier功能相似,它允许一组线程互相等待,直到到达某个公共屏障点。在涉及一组固定大小的线程的程序中,这些线程必须不时地互相等待,此时CyclicBarrier与CountDownLatch很有用。但是它与CyclicBarrier不同的是,CyclicBarrier是可以复用的,CountDownLatch是不可以复用的,怎么理解这句话呢?

假设有十个线程同时执行,如果是使用CyclicBarrier,将其初始化大小为5,那么第一组5个线程执行完毕后,下一组5个线程执行await()方法,仍会进入等待状态;而CountDownLatch则不可以,当第一组5个线程执行完毕后,CountDownLatch将不再生效,必须执行reset()方法,恢复到初始化状态,才可以继续使用。

CountDownLatch构造方法如下:

CyclicBarrier(int parties) 
          创建一个新的 CyclicBarrier,它将在给定数量的参与者(线程)处于等待状态时启动,但它不会在启动 barrier 时执行预定义的操作。 
------------------------------------------------------------------------------
CyclicBarrier(int parties, Runnable barrierAction) 
          创建一个新的 CyclicBarrier,它将在给定数量的参与者(线程)处于等待状态时启动,并在启动 barrier 时执行给定的屏障操作,该操作由最后一个进入 barrier 的线程执行。 
------------------------------------------------------------------------------

CountDownLatch的方法列表如下:

 int await() 
          在所有参与者都已经在此 barrier 上调用 await 方法之前,将一直等待。 
------------------------------------------------------------------------------
 int await(long timeout, TimeUnit unit) 
          在所有参与者都已经在此屏障上调用 await 方法之前将一直等待,或者超出了指定的等待时间。
------------------------------------------------------------------------------
 int getNumberWaiting() 
          返回当前在屏障处等待的参与者数目。 
------------------------------------------------------------------------------
 int getParties() 
          返回要求启动此 barrier 的参与者数目。 
------------------------------------------------------------------------------
 boolean isBroken() 
          查询此屏障是否处于损坏状态。 
------------------------------------------------------------------------------
 void reset() 
          将屏障重置为其初始状态。 
------------------------------------------------------------------------------

CountDownLatch的方法列表基本与CyclicBarrier一致,这里不再赘述。

CountDownLatch使用示例

public class CountDownLatchDemo {
    public static void main(String[] args) throws Exception {
        CountDownLatch countDownLatch = new CountDownLatch(5);
        for (int i = 0; i < 5; i++) {
            new Thread(() -> {
                System.out.println("等待其他线程执行开始");
                countDownLatch.countDown();
            }).start();
        }
        countDownLatch.await();
        System.out.println("全部线程执行完毕");
    }
}

输出结果:

当前线程:Thread-1等待其他线程执行开始
当前线程:Thread-0等待其他线程执行开始
当前线程:Thread-2等待其他线程执行开始
当前线程:Thread-4等待其他线程执行开始
当前线程:Thread-3等待其他线程执行开始
全部线程执行完毕

这个示例非常的简单,我们创建了一个countDownLatch,设置其大小为5,然后新建5个线程,在线程方法中,执行await()操作,每次一个线程执行await(),计数器就会减一,直到减到0之前,其他线程都会等待。

根据结果我们也可以看到,当全部5个线程都执行完毕之后,才输出了"全部线程执行完毕"。

CountDownLatch源码实现

CountDownLatch的实现是基于AQS的同步队列,通过重写AQS的抽象方法,实现其功能,我们来看一下它的实现:

扫描二维码关注公众号,回复: 5638561 查看本文章
public class CountDownLatch {

    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
    
    private final Sync sync;
    
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
    
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
    
    public void countDown() {
        sync.releaseShared(1);
    }
    
    public long getCount() {
        return sync.getCount();
    }
    
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}

CountDownLatch的构造函数是构造一个用给定计数初始化的CountDownLatch,并且构造函数内完成了sync的初始化,并设置了AQS的state值。

Sync继承于AQS,对AQS不了解的朋友可以看这篇AbstractQueuedSynchronizer源码分析,重写了其tryAcquireShared、tryReleaseShared这两个方法,这个后面我们会提到。

我们来分析一下最核心的两个方法,await()、countDown()。

await()

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

await()的实现是基于AQS的,调用了其acquireSharedInterruptibly()方法,是AQS的共享式获取同步状态,我们在前面介绍AQS的时候提到过,我们再来看一下它的实现:

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    //调用模板方法子类的实现,即CountdownLatch的实现
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

这段代码比较简单,首先判断了线程是否被中断,如果是,抛出异常,否则,调用子类的模板方法的实现:

protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

当同步状态不为0的时候,返回-1,进入doAcquireSharedInterruptibly()方法:

private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt()) //挂起当前线程
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

这里是同步队列构建的过程,由于在之前的文章中我们已经介绍过,这里就简单说一下其流程:

  • 1、增加一个共享节点到同步队列尾部,如果当前队列中没有节点,创建一个假的头结点,将新的节点的前驱节点指向该假头结点
  • 2、进入自旋,获取当前节点的前驱节点
  • 3、如果是头结点,调用模板方法的实现获取同步状态
  • 4、如果结果大于0,设置新的头结点,并释放掉当前节点,并跳出循环
  • 5、否则,将当前线程挂起
  • 6、过程中如果被打断,会抛出中断异常
//挂起当前线程
private final boolean parkAndCheckInterrupt() {
    LockSupport.park(this);
    return Thread.interrupted();
}

上述的过程,即将调用await方法的线程,加入同步队列当中去,将其挂起,直到其被唤醒。
如果您之前没有了解过AQS的实现,想必您看到这里,还是有点蒙的,请继续往下看,我们再说完countDown()方法后,我们使用一张图来描述整个过程,您就会一切明朗。

countDown():

public void countDown() {
    sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

countDown()方法同样是基于AQS的方法进行实现,调用其releaseShared(),共享式获取同步状态方法,调用其模板方法的实现:

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        // 获取状态
        int c = getState();
        // 如果状态值为0
        if (c == 0)
            return false;
        //否则,将状态值,减一
        int nextc = c-1;
        //CAS赋值
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

上面的方法是countdownLatch重写AQS的模板方法进行的实现,逻辑比较简单,请见注释。

我们知道,调用countdown的线程数,达到了countdownLatch的构造函数值的时候,countdownLatch会唤醒全部等待的线程,从tryReleaseShared方法我们可以看到,每次一个线程调用其countdown的时候,都会对state进行减一操作,直到state为0的时候,该方法返回true,即执行doReleaseShared方法:

private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

这个方法中最重要的是unparkSuccessor()方法,它会唤醒当前节点的线程。

private void unparkSuccessor(Node node) {
    int ws = node.waitStatus;
    if (ws < 0)
        compareAndSetWaitStatus(node, ws, 0);
    Node s = node.next;
    if (s == null || s.waitStatus > 0) {
        s = null;
        for (Node t = tail; t != null && t != node; t = t.prev)
            if (t.waitStatus <= 0)
                s = t;
    }
    if (s != null)
        LockSupport.unpark(s.thread);
}

我们来从头梳理一遍这个逻辑:

  • 1、线程执行await方法,进入同步队列,判断当前节点的前驱节点是否为头结点,如果不是,线程挂起,进行等待;
  • 2、线程调用tryReleaseShared()方法时,会将当前同步状态减一,当state为0的时候,调用doReleaseShared();
  • 3、调用doReleaseShared做了一件事,唤醒头结点的线程;
  • 4、头结点线程(第一个节点的头结点是假节点,没有持有线程,会唤醒其下一个真实节点的线程)被唤醒后,拿到同步资源,退出循环,并唤醒下一个节点的线程,依次类推,直到唤醒全部同步队列的线程

我们用一张图描述一下:

image
在这里插入图片描述
上面是CountdownLatch简要的流程,这里最重要的其实在于countdown方法,将state置为0的时候,对头部节点的唤醒,这个我认为是整个设计的精髓所在,希望读者看到这里的时候,可以自己去翻一下源码,自己去理解一下其设计的方式。

CountDownLatch VS CyclicBarrier

这两个类真的非常像,它们都能够实现线程之间的等待,只不过它们侧重点不同,其区别有以下几点:

1、CountDownLatch一般用于某个线程A等待若干个其他线程执行完任务之后,它才执行;
2、而CyclicBarrier一般用于一组线程互相等待至某个状态,然后这一组线程再同时执行;
3、CountDownLatch是不能够重用的,而CyclicBarrier是可以重用的。

我们在开发并发程序的时候,可以根据具体的业务场景,进行选择使用哪一种。

结语

本篇,我们介绍了CountDownLatch的使用方法及其实现机制,希望可以对您有所帮助,CountDownLatch的实现是基于AQS的,如果没有了解过AQS的读者,可以看一下之前的文章 聊聊并发:(八)concurrent包之AbstractQueuedSynchronizer源码实现分析

下篇预告:聊聊并发:(十四)concurrent包并发辅助类之Semaphore分析

更多Java干货文章请关注我的个人微信公众号:老宣与你聊Java

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/wtopps/article/details/84952842
今日推荐