CountDownLatch原理以及使用案例

CountDownLatch介绍

CountDownLatch是一个同步工具类,它允许一个或多个线程一直等待,直到其他线程执行完后再执行。

CountDownLatch是在java1.5被引入的,跟它一起被引入的并发工具类还有CyclicBarrier、Semaphore、ConcurrentHashMap和BlockingQueue,它们都存在于java.util.concurrent包下

CountDownLatch 定义了一个计数器,和一个阻塞队列, 当计数器的值递减为0之前,阻塞队列里面的线程处于挂起状态,当计数器递减到0时会唤醒阻塞队列所有线程,这里的计数器是一个标志,可以表示一个任务一个线程,也可以表示一个倒计时器,CountDownLatch可以解决那些一个或者多个线程在执行之前必须依赖于某些必要的前提业务先执行的场景。

CountDownLatch原理

CountDownLatch是通过一个计数器来实现的,计数器的初始化值为线程的数量。每当一个线程完成了自己的任务后,计数器的值就相应得减1。当计数器到达0时,表示所有的线程都已完成任务,然后在闭锁上等待的线程就可以恢复执行任务。

注:这是一个一次性操作 - 计数无法重置。 如果你需要一个重置的版本计数,考虑使用CyclicBarrier。

CountDownLatch原理示意图

CountDownLatch 内部结构

1.Sync 是一个静态内部类 继承 AbstractQueuedSynchronizer

 /**
     * CountDownLatch的同步控制器
     * 使用AQS状态表示计数。
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
 
        //构造方法 CountDownLatch 的构造方法最终调用的是 Sync的构造。
        Sync(int count) {
            setState(count); //初始化count
        }
 
        //返回当前count 计数
        int getCount() {
            return getState();
        }
 
        // 试图在共享模式下获取对象状态
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
        // 试图设置状态来反映共享模式下的一个释放
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

注:从源码可知,其底层是由AQS提供支持。AQS 见下章分解。

2. await()

           此函数将会使当前线程在锁存器倒计数至零之前一直等待,除非线程被中断。其源码如下  

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
 
   //-----------------------------------------------------------
 
    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        //这里可以看到 最终调用的是Sync中的 tryAcquireShared 方法 
        // return (count == 0) ? 1 : -1;
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }
 
 
//-------------------------------------------------------
 
 
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

3.countDown() 

  此函数将递减锁存器的计数,如果计数到达零,则释放所有等待的线程  

   /**   
     * count值减 1,直到计数达到零,释放所有等待的线程 。
     *      
     *  <p>如果当前计数大于零,则递减。
     *   如果新计数为零,则重新启用所有等待的线程 ,达到线程调度的目的。
     *      
     * <p>如果当前计数等于零,则没有任何反应。
     */
    public void countDown() {
        sync.releaseShared(1);
    }
 
//-------------------------------------------------------------------------
   /**
     * 此函数会以共享模式释放对象,
     * 并且在函数中会调用到CountDownLatch的tryReleaseShared函数,
     * 当且仅当新计数返回0时,会调用AQS的doReleaseShared函数, 
     */
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }

*CountDownLatch和CyclicBarrier区别:
1.countDownLatch是一个计数器,线程完成一个记录一个,计数器递减,只能只用一次
2.CyclicBarrier的计数器更像一个阀门,需要所有线程都到达,然后继续执行,计数器递增,提供reset功能,可以多次使用

CountDownLatch 使用示例

示例1

package com.kids.web.countdownlatch;
 
import java.util.concurrent.CountDownLatch;
 
public class Worker implements Runnable {
 
    private int num;
 
    private final CountDownLatch startSignal;
    private final CountDownLatch doneSignal;
 
    Worker(CountDownLatch startSignal, CountDownLatch doneSignal, int num) {
        this.startSignal = startSignal;
        this.doneSignal = doneSignal;
        this.num = num;
    }
 
    @Override
    public void run() {
        try {
            Thread.sleep(2000);
            System.out.println("do startSignal.await() before");
            startSignal.await();
            System.out.println("do startSignal.await() after");
 
            doWork();
            doneSignal.countDown();
            System.out.println("===count =" + doneSignal.getCount() + "========== num =" + num);
 
        } catch (InterruptedException ex) {
        } // return;
    }
 
    void doWork() {
        System.out.println("do work..." + num);
    }
}
package com.kids.web.countdownlatch;
 
import java.util.concurrent.CountDownLatch;
 
public class Driver {
    public static void main(String[] args) throws InterruptedException {
        CountDownLatch startSignal = new CountDownLatch(1);
        int N = 10;
        CountDownLatch doneSignal = new CountDownLatch(N);
 
        for (int i = 0; i < N; ++i) // create and start threads
            new Thread(new Worker(startSignal, doneSignal,i)).start();
 
        doSomethingElse1();            // don't let run yet
        Thread.sleep(1000 * 5);
        startSignal.countDown();      // let all threads proceed
        doSomethingElse2();
        Thread.sleep(1000 * 5);
        doneSignal.await();           // wait for all to finish
    }
 
    private static void doSomethingElse1() {
        System.out.println("====================startSignal.countDown() before");
    }
 
    private static void doSomethingElse2() {
        System.out.println("====================startSignal.countDown() after");
    }
}

====================startSignal.countDown() before
do startSignal.await() before ----count=1 ---- num =0
do startSignal.await() before ----count=1 ---- num =2
do startSignal.await() before ----count=1 ---- num =4
do startSignal.await() before ----count=1 ---- num =1
do startSignal.await() before ----count=1 ---- num =3
====================startSignal.countDown() after
do startSignal.await() after  ----count=0 ---- num =3
do startSignal.await() after  ----count=0 ---- num =1
do startSignal.await() after  ----count=0 ---- num =4
do startSignal.await() after  ----count=0 ---- num =2
do work...2
doneSignal.countDown()===count =4========== num =2
do startSignal.await() after  ----count=0 ---- num =0
do work...0
do work...4
doneSignal.countDown()===count =2========== num =4
do work...1
doneSignal.countDown()===count =1========== num =1
do work...3
doneSignal.countDown()===count =0========== num =3
doneSignal.countDown()===count =3========== num =0

等待其他线程执行完毕

Process finished with exit code 0

说明:通过示例结果中可以看出,在startSignal调用countDown()之前程序在startSignal.await() 处堵塞,

此时startSignal的count为1。在startSignal调用countDown()之后,count =0 时 开始执行startSignal.await()之后的业务逻辑。

doneSignal调用countDown() 直到其count=0 回到主线程。等待doneSignal调用await()结束。

注意:doneSignal最终不调用await() 的话 该线程始终处于等待状态。(本人也不知具体原因,在await 方法中 如果计数为0的话并没有处理任何事情。而且通过sleep 看到只有doneSignal的计数为0后才会返回主线程执行。)

示例2

package com.kids.web.countdownlatch;
 
import java.util.concurrent.CountDownLatch;
 
public class Worker2 implements Runnable {
 
    private final CountDownLatch doneSignal;
    private final int i;
 
    public Worker2(CountDownLatch doneSignal, int i) {
        this.doneSignal = doneSignal;
        this.i = i;
    }
 
    public void run() {
        doWork(i);
        doneSignal.countDown();
        System.out.println("=============count =" + doneSignal.getCount() + "------- i =" + i);
    }
 
    void doWork(int i) {
        System.out.println("do work..." + i);
    }
}
package com.kids.web.countdownlatch;
 
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
 
public class Driver2 {
    public static void main(String[] args) throws InterruptedException {
        int N = 20;
        CountDownLatch doneSignal = new CountDownLatch(N);
        Executor e = Executors.newFixedThreadPool(1);
 
        for (int i = 0; i < N; ++i) // create and start threads
            e.execute(new Worker2(doneSignal, i));
        doneSignal.await();           // wait for all to finish
    }
 
}

do work...0
=============count =19------- i =0
do work...1
=============count =18------- i =1
do work...2
=============count =17------- i =2
do work...3
=============count =16------- i =3
。。。。。。

。。。。。。

do work...18
=============count =1------- i =18
do work...19
=============count =0------- i =19

说明:这里 Executor e = Executors.newFixedThreadPool(1); 只是为了可以从打印结果中更直观的看出 count的递减。


实例3

用CountDownLatch 来优化我们的报表统计

功能现状

运营系统有统计报表、业务为统计每日的用户新增数量、订单数量、商品的总销量、总销售额......等多项指标统一展示出来,因为数据量比较大,统计指标涉及到的业务范围也比较多,所以这个统计报表的页面一直加载很慢,所以需要对统计报表这块性能需进行优化。

问题分析

统计报表页面涉及到的统计指标数据比较多,每个指标需要单独的去查询统计数据库数据,单个指标只要几秒钟,但是页面的指标有10多个,所以整体下来页面渲染需要将近一分钟。

解决方案

任务时间长是因为统计指标多,而且指标是串行的方式去进行统计的,我们只需要考虑把这些指标从串行化的执行方式改成并行的执行方式,那么整个页面的时间的渲染时间就会大大的缩短, 如何让多个线程同步的执行任务,我们这里考虑使用多线程,每个查询任务单独创建一个线程去执行,这样每个统计指标就可以并行的处理了。

要求

因为主线程需要每个线程的统计结果进行聚合,然后返回给前端渲染,所以这里需要提供一种机制让主线程等所有的子线程都执行完之后再对每个线程统计的指标进行聚合。 这里我们使用CountDownLatch 来完成此功能。

模拟代码

1、分别统计4个指标用户新增数量、订单数量、商品的总销量、总销售额;

2、假设每个指标执行时间为3秒。如果是串行化的统计方式那么总执行时间会为12秒。

3、我们这里使用多线程并行,开启4个子线程分别进行统计

4、主线程等待4个子线程都执行完毕之后,返回结果给前端。

​
    //用于聚合所有的统计指标
    private static Map map=new HashMap();
    //创建计数器,这里需要统计4个指标
    private static CountDownLatch countDownLatch=new CountDownLatch(4);
​
    public static void main(String[] args) {
​
        //记录开始时间
        long startTime=System.currentTimeMillis();
​
        Thread countUserThread=new Thread(new Runnable() {
            public void run() {
                try {
                    System.out.println("正在统计新增用户数量");
                    Thread.sleep(3000);//任务执行需要3秒
                    map.put("userNumber",1);//保存结果值
                    countDownLatch.countDown();//标记已经完成一个任务
                    System.out.println("统计新增用户数量完毕");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
​
            }
        });
        Thread countOrderThread=new Thread(new Runnable() {
            public void run() {
                try {
                    System.out.println("正在统计订单数量");
                    Thread.sleep(3000);//任务执行需要3秒
                    map.put("countOrder",2);//保存结果值
                    countDownLatch.countDown();//标记已经完成一个任务
                    System.out.println("统计订单数量完毕");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
​
            }
        });
​
        Thread countGoodsThread=new Thread(new Runnable() {
            public void run() {
                try {
                    System.out.println("正在商品销量");
                    Thread.sleep(3000);//任务执行需要3秒
                    map.put("countGoods",3);//保存结果值
                    countDownLatch.countDown();//标记已经完成一个任务
                    System.out.println("统计商品销量完毕");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
​
            }
        });
​
        Thread countmoneyThread=new Thread(new Runnable() {
            public void run() {
                try {
                    System.out.println("正在总销售额");
                    Thread.sleep(3000);//任务执行需要3秒
                    map.put("countmoney",4);//保存结果值
                    countDownLatch.countDown();//标记已经完成一个任务
                    System.out.println("统计销售额完毕");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
​
            }
        });
        //启动子线程执行任务
        countUserThread.start();
        countGoodsThread.start();
        countOrderThread.start();
        countmoneyThread.start();
​
        try {
            //主线程等待所有统计指标执行完毕
            countDownLatch.await();
            long endTime=System.currentTimeMillis();//记录结束时间
            System.out.println("------统计指标全部完成--------");
            System.out.println("统计结果为:"+map.toString());
            System.out.println("任务总执行时间为"+(endTime-startTime)/1000+"秒");
​
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
​
​
    }

执行结果

preview

猜你喜欢

转载自blog.csdn.net/wr_java/article/details/115127491