【JUC源码】CountDownLatch源码分析&使用示例

CountDownLatch 中文有的叫做计数器,也有翻译为计数锁,其最大的作用不是为了加锁,而是通过计数达到等待的功能,主要有两种形式的等待:

  1. 让一组线程在全部启动完成之后,再一起执行(先启动的线程需要阻塞等待后启动的线程,直到一组线程全部都启动完成后,再一起执行);
  2. 主线程等待另外一组线程都执行完成之后,再继续执行。

1.结构

CountDownLatch 的核心成员变量及主要构造函数如下:

public class CountDownLatch {
    
    
	// 从 Sync 的继承关系就可以看出,CountDownLatch也是基于AQS框架实现的
    private static final class Sync extends AbstractQueuedSynchronizer {
    
    
        private static final long serialVersionUID = 4982264981922014374L;
		
		// 构造函数,直接设置state=count
        Sync(int count) {
    
    
            setState(count);
        }
		// 调用AQS方法获取state
        int getCount() {
    
    
            return getState();
        }
		// 能否获取到共享锁。如果当前同步器的状态是 0 的话,表示可获得锁
        protected int tryAcquireShared(int acquires) {
    
    
            return (getState() == 0) ? 1 : -1; // state!=0,就拿锁失败
        }
		// 对 state 进行递减,直到 state 变成 0;state 递减为 0 时,返回 true,其余返回 false
        protected boolean tryReleaseShared(int releases) {
    
    
            // 自旋保证 CAS 一定可以成功
            for (;;) {
    
    
                int c = getState();
                // state 已经是 0 了,直接返回 false
                if (c == 0)
                    return false;
                // state--
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
    
    private final Sync sync;
    
    //-----------------------------构造函数------------------------------------
    // 无空参构造,必须传入count,count相当于要等待的线程数
    public CountDownLatch(int count) {
    
    
        if (count < 0) throw new IllegalArgumentException("count < 0");
        // 将count传给sync
        this.sync = new Sync(count);
    }
}

在看具体方法的源码前,先放上一个简单的使用示例:模拟了100米赛跑,10名选手已经准备就绪,只等裁判一声令下。当所有人都到达终点时,比赛结束。我们和容易想到主线程模拟裁判,开启十个子线程模拟运动员,但是会有以下两个问题:

  1. 10个子线程必须等到主线程发出号令(控制台打印Game Start)之后再运行
  2. 主线程必须等待10个子线程运行结束之后再退出

下面就看看如何通过 2个 CountDownLatch 来解决这俩问题。

public class CountDownLatchTest {
    
    

    public static void main(String[] args) throws InterruptedException {
    
    

        // 开始的倒数锁。count设置为1 是因为主线程只用-1
        final CountDownLatch begin = new CountDownLatch(1);  

        // 结束的倒数锁,count设置为10 是因为有10个子线程都要-1
        final CountDownLatch end = new CountDownLatch(10);  

        // 通过线程池创造出十名选手 (十个线程)
        final ExecutorService exec = Executors.newFixedThreadPool(10);  
		
		// 让这十个线程运行起来
        for (int index = 0; index < 10; index++) {
    
    
            final int NO = index + 1;  // 编号【1,10】
            Runnable run = new Runnable() {
    
    
                public void run() {
    
      
                    try {
    
      
                        // 如果当前计数为零(主线程已就绪),则此方法立即返回。
                        // 如果当前计数不为0(主线程还未调用countDown),等待。
                        begin.await();  
                        Thread.sleep((long) (Math.random() * 10000));  
                        System.out.println("No." + NO + " arrived");  
                    } catch (InterruptedException e) {
    
      
                    } finally {
    
      
                        // 每个选手到达终点(线程执行完毕)时,end就减一
                        end.countDown();
                    }  
                }  
            };  
            exec.submit(run);
        }  
        System.out.println("Game Start");  
        // begin减一,开始游戏
        begin.countDown();  
        // 主线程会阻塞在这里,等待end变为0,即所有选手到达终点
        end.await();  
        System.out.println("Game Over");  
        exec.shutdown();  
    }
}

控制台输出结果如下:

Game Start
No.9 arrived
No.6 arrived
No.8 arrived
No.7 arrived
No.10 arrived
No.1 arrived
No.5 arrived
No.4 arrived
No.2 arrived
No.3 arrived
Game Over

2.方法解析 & api

CountDownLatch 实质上就是使用了AQS共享锁模式

  1. 构造时传入count,将count赋值给AQS的state
  2. await:将当前线程加入同步队列,休眠。
  3. countDown:state--,当 state=0 时唤醒所有阻塞在await的线程恢复运行

注:这里为什么不用独占锁?因为await的位置可能有多个,等state=0后需要都唤醒

2.1 await

public void await() throws InterruptedException {
    
    
    sync.acquireSharedInterruptibly(1);
}

// 带有超时时间的,最终都会转化成毫秒
public boolean await(long timeout, TimeUnit unit)
    throws InterruptedException {
    
    
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

在 AQS 中 acquireSharedInterruptibly 方法最终会调用 doAcquireInterruptibly 方法
在这里插入图片描述
doAcquireInterruptibly

  1. 将当前线程封装成node,然后加入同步队列。
  2. 若当前线程的node非队二且state不等于0,则线程进入阻塞状态。
  3. 通过自旋保证所有被唤醒的线程都能依次恢复运行。
private void doAcquireInterruptibly(int arg)
        throws InterruptedException {
    
    
    	// 将当前线程封装为node,并加到同步队列队尾
        final Node node = addWaiter(Node.EXCLUSIVE);
        boolean failed = true;
        try {
    
    
        	// 自旋,保证所有被唤醒的线程都能依次恢复运行
            for (;;) {
    
    
                final Node p = node.predecessor();
                // 当前node前进到队二 && tryAcquire成功(state减到0),就可以执行了
                if (p == head && tryAcquire(arg)) {
    
    
                    setHead(node);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
    
    
            if (failed)
                cancelAcquire(node);
        }
}

2.2 countDown

线程调用countDown方法后,会将AQS的state-1,若state=0了,就会唤醒所有阻塞在await处的线程。

public void countDown() {
    
    
    sync.releaseShared(1);
}

releaseShared

public final boolean releaseShared(int arg) {
    
    
    // 将state-1,若state=0了,表示当前线程释放锁成功
    if (tryReleaseShared(arg)) {
    
    
        // 唤醒后续节点
        doReleaseShared();
        return true;
    }
    return false;
}

tryReleaseShared

protected boolean tryReleaseShared(int releases) {
    
    
    // 自旋保证 CAS 一定可以成功
    for (;;) {
    
    
        int c = getState();
        // state 已经是 0 了,直接返回 false
        if (c == 0)
            return false;
        // state--
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

doReleaseShared

private void doReleaseShared() {
    
    
    // 自旋,保证所有线程正常的线程都能被唤醒
    for (;;) {
    
    
        Node h = head;
        // 还没有到队尾,此时队列中至少有两个节点
        if (h != null && h != tail) {
    
    
            int ws = h.waitStatus;
            // 如果头结点状态是 SIGNAL ,说明后续节点都需要唤醒
            if (ws == Node.SIGNAL) {
    
    
                // CAS 保证只有一个节点可以运行唤醒的操作
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                // 进行唤醒操作
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        // 退出自旋条件 h==head,一般出现于以下两种情况
        // 第一种情况,头节点没有发生移动,结束。
        // 第二种情况,因为此方法可以被两处调用,一次是获得锁的地方,一处是释放锁的地方,
        // 加上共享锁的特性就是可以多个线程获得锁,也可以释放锁,这就导致头节点可能会发生变化,
        // 如果头节点发生了变化,就继续循环,一直循环到头节点不变化时,结束循环。
        if (h == head)                   // loop if head changed
            break;
    }
}

3.模拟退款示例

  • 小明在淘宝上买了一个商品,觉得不好,把这个商品退掉(商品还没有发货,只退钱),我们叫做单商品退款,单商品退款在后台系统中运行时,整体耗时 30 毫秒。
  • 双 11,小明在淘宝上买了 40 个商品,生成了同一个订单(实际可能会生成多个订单,为了方便描述,我们说成一个),第二天小明发现其中 30 个商品是自己冲动消费的,需要把 30 个商品一起退掉。
// 单商品退款,耗时 30 毫秒,退款成功返回 true,失败返回 false
@Slf4j
public class RefundDemo {
    
    

  /**
   * 根据商品 ID 进行退款
   * @param itemId
   * @return
   */
  public boolean refundByItem(Long itemId) {
    
    
    try {
    
    
      // 线程沉睡 30 毫秒,模拟单个商品退款过程
      Thread.sleep(30);
      log.info("refund success,itemId is {}", itemId);
      return true;
    } catch (Exception e) {
    
    
      log.error("refundByItemError,itemId is {}", itemId);
      return false;
    }
  }
}
@Slf4j
public class BatchRefundDemo {
    
    
  //定义线程池
  public static final ExecutorService EXECUTOR_SERVICE =
      new ThreadPoolExecutor(10, 10, 0L,
                                TimeUnit.MILLISECONDS,
                                new LinkedBlockingQueue<>(20));
  @Test
  public void batchRefund() throws InterruptedException {
    
    
    // state 初始化为 30 
    CountDownLatch countDownLatch = new CountDownLatch(30);
    RefundDemo refundDemo = new RefundDemo();

    // 准备 30 个商品
    List<Long> items = Lists.newArrayListWithCapacity(30);
    for (int i = 0; i < 30; i++) {
    
    
      items.add(Long.valueOf(i+""));
    }

    // 准备开始批量退款
    List<Future> futures = Lists.newArrayListWithCapacity(30);
    for (Long item : items) {
    
    
      // 使用 Callable,因为我们需要等到返回值
      Future<Boolean> future = EXECUTOR_SERVICE.submit(new Callable<Boolean>() {
    
    
        @Override
        public Boolean call() throws Exception {
    
    
          boolean result = refundDemo.refundByItem(item);
          // 每个子线程都会执行 countDown,使 state -1 ,但只有最后一个才能真的唤醒主线程
          countDownLatch.countDown();
          return result;
        }
      });
      // 收集批量退款的结果
      futures.add(future);
    }

    log.info("30 个商品已经在退款中");
    // 使主线程阻塞,一直等待 30 个商品都退款完成,才能继续执行
    countDownLatch.await();
    log.info("30 个商品已经退款完成");
    // 拿到所有结果进行分析
    List<Boolean> result = futures.stream().map(fu-> {
    
    
      try {
    
    
        // get 的超时时间设置的是 1 毫秒,是为了说明此时所有的子线程都已经执行完成了
        return (Boolean) fu.get(1,TimeUnit.MILLISECONDS);
      } catch (InterruptedException e) {
    
    
        e.printStackTrace();
      } catch (ExecutionException e) {
    
    
        e.printStackTrace();
      } catch (TimeoutException e) {
    
    
        e.printStackTrace();
      }
      return false;
    }).collect(Collectors.toList());
    
     // 打印结果统计
    long success = result.stream().filter(r->r.equals(true)).count();
    log.info("执行结果成功{},失败{}",success,result.size()-success);
  }
}

通过以上代码,30 个商品退款完成之后,整体耗时大概在 200 毫秒左右。而通过 for 循环单商品进行退款,大概耗时在 1 秒左右,前后性能相差 5 倍左右,for 循环退款的代码如下:

long begin1 = System.currentTimeMillis();
for (Long item : items) {
    
    
  refundDemo.refundByItem(item);
}
log.info("for 循环单个退款耗时{}",System.currentTimeMillis()-begin1);

文末提出一个问题作为全文的回顾总结。如果一个线程需要等待一组线程全部执行完之后再继续执行,有什么好的办法么?是如何实现的?

答:CountDownLatch 就提供了这样的机制,比如一组线程有 5 个,只需要在初始化 CountDownLatch 时,给同步器的 state 赋值为 5,主线程执行 CountDownLatch.await ,子线程都执行 CountDownLatch.countDown 即可。

猜你喜欢

转载自blog.csdn.net/weixin_43935927/article/details/108718739