[JUC source code] CountDownLatch source code analysis & usage example

CountDownLatch is called a counter in Chinese, and it is also translated as a count lock. Its biggest function is not to lock, but to achieve the function of waiting through counting. There are two main forms of waiting:

  1. Let a group of threads execute together after all the startup is completed (the first started thread needs to block the thread that is started after waiting until all the threads are started, and then execute together);
  2. The main thread waits for the execution of another group of threads to complete before continuing.

1. Structure

The core member variables and main constructors of CountDownLatch are as follows:

public class CountDownLatch {
    
    
	// 从 Sync 的继承关系就可以看出,CountDownLatch也是基于AQS框架实现的
    private static final class Sync extends AbstractQueuedSynchronizer {
    
    
        private static final long serialVersionUID = 4982264981922014374L;
		
		// 构造函数,直接设置state=count
        Sync(int count) {
    
    
            setState(count);
        }
		// 调用AQS方法获取state
        int getCount() {
    
    
            return getState();
        }
		// 能否获取到共享锁。如果当前同步器的状态是 0 的话,表示可获得锁
        protected int tryAcquireShared(int acquires) {
    
    
            return (getState() == 0) ? 1 : -1; // state!=0,就拿锁失败
        }
		// 对 state 进行递减,直到 state 变成 0;state 递减为 0 时,返回 true,其余返回 false
        protected boolean tryReleaseShared(int releases) {
    
    
            // 自旋保证 CAS 一定可以成功
            for (;;) {
    
    
                int c = getState();
                // state 已经是 0 了,直接返回 false
                if (c == 0)
                    return false;
                // state--
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
    
    private final Sync sync;
    
    //-----------------------------构造函数------------------------------------
    // 无空参构造,必须传入count,count相当于要等待的线程数
    public CountDownLatch(int count) {
    
    
        if (count < 0) throw new IllegalArgumentException("count < 0");
        // 将count传给sync
        this.sync = new Sync(count);
    }
}

Before looking at the source code of the specific method, let's put a simple example of use: a 100-meter race is simulated, 10 runners are ready, just waiting for the referee to give an order. When everyone reaches the finish line, the game ends. It is easy to think that the main thread simulates the referee and opens ten sub-threads to simulate athletes, but there are two problems:

  1. 10 sub-threads must wait until the main thread issues a command (the console prints Game Start) before running
  2. The main thread must wait for the completion of the 10 child threads before exiting

Let's see how to solve these two problems through two CountDownLatch.

public class CountDownLatchTest {
    
    

    public static void main(String[] args) throws InterruptedException {
    
    

        // 开始的倒数锁。count设置为1 是因为主线程只用-1
        final CountDownLatch begin = new CountDownLatch(1);  

        // 结束的倒数锁,count设置为10 是因为有10个子线程都要-1
        final CountDownLatch end = new CountDownLatch(10);  

        // 通过线程池创造出十名选手 (十个线程)
        final ExecutorService exec = Executors.newFixedThreadPool(10);  
		
		// 让这十个线程运行起来
        for (int index = 0; index < 10; index++) {
    
    
            final int NO = index + 1;  // 编号【1,10】
            Runnable run = new Runnable() {
    
    
                public void run() {
    
      
                    try {
    
      
                        // 如果当前计数为零(主线程已就绪),则此方法立即返回。
                        // 如果当前计数不为0(主线程还未调用countDown),等待。
                        begin.await();  
                        Thread.sleep((long) (Math.random() * 10000));  
                        System.out.println("No." + NO + " arrived");  
                    } catch (InterruptedException e) {
    
      
                    } finally {
    
      
                        // 每个选手到达终点(线程执行完毕)时,end就减一
                        end.countDown();
                    }  
                }  
            };  
            exec.submit(run);
        }  
        System.out.println("Game Start");  
        // begin减一,开始游戏
        begin.countDown();  
        // 主线程会阻塞在这里,等待end变为0,即所有选手到达终点
        end.await();  
        System.out.println("Game Over");  
        exec.shutdown();  
    }
}

The console output is as follows:

Game Start
No.9 arrived
No.6 arrived
No.8 arrived
No.7 arrived
No.10 arrived
No.1 arrived
No.5 arrived
No.4 arrived
No.2 arrived
No.3 arrived
Game Over

2. Method analysis & api

CountDownLatch essentially uses the AQS shared lock mode

  1. Pass in count when constructing, and assign count to AQS state
  2. await: Add the current thread to the synchronization queue and sleep.
  3. countDown: state--When state=0, wake up all threads blocked in await to resume operation

Note: Why not use exclusive locks here? Because there may be multiple await positions, you need to wake up all after state=0

2.1 await

public void await() throws InterruptedException {
    
    
    sync.acquireSharedInterruptibly(1);
}

// 带有超时时间的,最终都会转化成毫秒
public boolean await(long timeout, TimeUnit unit)
    throws InterruptedException {
    
    
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

In AQS, the acquireSharedInterruptibly method will eventually call the doAcquireInterruptibly method
Insert picture description here
doAcquireInterruptibly

  1. Encapsulate the current thread into a node, and then join the synchronization queue.
  2. If the node of the current thread is not team two and state is not equal to 0, the thread enters the blocking state.
  3. Through the spin, it is guaranteed that all awakened threads can resume running in turn.
private void doAcquireInterruptibly(int arg)
        throws InterruptedException {
    
    
    	// 将当前线程封装为node,并加到同步队列队尾
        final Node node = addWaiter(Node.EXCLUSIVE);
        boolean failed = true;
        try {
    
    
        	// 自旋,保证所有被唤醒的线程都能依次恢复运行
            for (;;) {
    
    
                final Node p = node.predecessor();
                // 当前node前进到队二 && tryAcquire成功(state减到0),就可以执行了
                if (p == head && tryAcquire(arg)) {
    
    
                    setHead(node);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
    
    
            if (failed)
                cancelAcquire(node);
        }
}

2.2 countDown

After the thread calls the countDown method, it will AQS state-1, if state=0, it will wake up all threads blocked in await.

public void countDown() {
    
    
    sync.releaseShared(1);
}

releaseShared

public final boolean releaseShared(int arg) {
    
    
    // 将state-1,若state=0了,表示当前线程释放锁成功
    if (tryReleaseShared(arg)) {
    
    
        // 唤醒后续节点
        doReleaseShared();
        return true;
    }
    return false;
}

tryReleaseShared

protected boolean tryReleaseShared(int releases) {
    
    
    // 自旋保证 CAS 一定可以成功
    for (;;) {
    
    
        int c = getState();
        // state 已经是 0 了,直接返回 false
        if (c == 0)
            return false;
        // state--
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

doReleaseShared

private void doReleaseShared() {
    
    
    // 自旋,保证所有线程正常的线程都能被唤醒
    for (;;) {
    
    
        Node h = head;
        // 还没有到队尾,此时队列中至少有两个节点
        if (h != null && h != tail) {
    
    
            int ws = h.waitStatus;
            // 如果头结点状态是 SIGNAL ,说明后续节点都需要唤醒
            if (ws == Node.SIGNAL) {
    
    
                // CAS 保证只有一个节点可以运行唤醒的操作
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                // 进行唤醒操作
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        // 退出自旋条件 h==head,一般出现于以下两种情况
        // 第一种情况,头节点没有发生移动,结束。
        // 第二种情况,因为此方法可以被两处调用,一次是获得锁的地方,一处是释放锁的地方,
        // 加上共享锁的特性就是可以多个线程获得锁,也可以释放锁,这就导致头节点可能会发生变化,
        // 如果头节点发生了变化,就继续循环,一直循环到头节点不变化时,结束循环。
        if (h == head)                   // loop if head changed
            break;
    }
}

3. Example of simulated refund

  • Xiao Ming bought a product on Taobao and felt it was not good, so he returned the product (the product has not been shipped yet, only the money is refunded). We call it a single product refund. When a single product refund runs in the background system, the overall consumption Time is 30 milliseconds.
  • On Double 11, Xiao Ming bought 40 items on Taobao and generated the same order (in fact, multiple orders may be generated, but for the sake of description, we will say one). The next day, Xiao Ming found out that 30 of the items were consumed impulsively. Yes, you need to return 30 items together.
// 单商品退款,耗时 30 毫秒,退款成功返回 true,失败返回 false
@Slf4j
public class RefundDemo {
    
    

  /**
   * 根据商品 ID 进行退款
   * @param itemId
   * @return
   */
  public boolean refundByItem(Long itemId) {
    
    
    try {
    
    
      // 线程沉睡 30 毫秒,模拟单个商品退款过程
      Thread.sleep(30);
      log.info("refund success,itemId is {}", itemId);
      return true;
    } catch (Exception e) {
    
    
      log.error("refundByItemError,itemId is {}", itemId);
      return false;
    }
  }
}
@Slf4j
public class BatchRefundDemo {
    
    
  //定义线程池
  public static final ExecutorService EXECUTOR_SERVICE =
      new ThreadPoolExecutor(10, 10, 0L,
                                TimeUnit.MILLISECONDS,
                                new LinkedBlockingQueue<>(20));
  @Test
  public void batchRefund() throws InterruptedException {
    
    
    // state 初始化为 30 
    CountDownLatch countDownLatch = new CountDownLatch(30);
    RefundDemo refundDemo = new RefundDemo();

    // 准备 30 个商品
    List<Long> items = Lists.newArrayListWithCapacity(30);
    for (int i = 0; i < 30; i++) {
    
    
      items.add(Long.valueOf(i+""));
    }

    // 准备开始批量退款
    List<Future> futures = Lists.newArrayListWithCapacity(30);
    for (Long item : items) {
    
    
      // 使用 Callable,因为我们需要等到返回值
      Future<Boolean> future = EXECUTOR_SERVICE.submit(new Callable<Boolean>() {
    
    
        @Override
        public Boolean call() throws Exception {
    
    
          boolean result = refundDemo.refundByItem(item);
          // 每个子线程都会执行 countDown,使 state -1 ,但只有最后一个才能真的唤醒主线程
          countDownLatch.countDown();
          return result;
        }
      });
      // 收集批量退款的结果
      futures.add(future);
    }

    log.info("30 个商品已经在退款中");
    // 使主线程阻塞,一直等待 30 个商品都退款完成,才能继续执行
    countDownLatch.await();
    log.info("30 个商品已经退款完成");
    // 拿到所有结果进行分析
    List<Boolean> result = futures.stream().map(fu-> {
    
    
      try {
    
    
        // get 的超时时间设置的是 1 毫秒,是为了说明此时所有的子线程都已经执行完成了
        return (Boolean) fu.get(1,TimeUnit.MILLISECONDS);
      } catch (InterruptedException e) {
    
    
        e.printStackTrace();
      } catch (ExecutionException e) {
    
    
        e.printStackTrace();
      } catch (TimeoutException e) {
    
    
        e.printStackTrace();
      }
      return false;
    }).collect(Collectors.toList());
    
     // 打印结果统计
    long success = result.stream().filter(r->r.equals(true)).count();
    log.info("执行结果成功{},失败{}",success,result.size()-success);
  }
}

Through the above code, after the refund of 30 products is completed, the overall time-consuming is about 200 milliseconds. However, it takes about 1 second to refund a single product through the for loop, and the performance difference is about 5 times. The code for the for loop refund is as follows:

long begin1 = System.currentTimeMillis();
for (Long item : items) {
    
    
  refundDemo.refundByItem(item);
}
log.info("for 循环单个退款耗时{}",System.currentTimeMillis()-begin1);

At the end of the article, a question is raised as a review and summary of the full text. If a thread needs to wait for a group of threads to execute before continuing, is there any good way? How is it achieved?

Answer: CountDownLatch provides such a mechanism. For example, if there are 5 threads in a group, you only need to assign 5 to the state of the synchronizer when initializing CountDownLatch, the main thread executes CountDownLatch.await, and the child threads execute CountDownLatch.countDown. .

Guess you like

Origin blog.csdn.net/weixin_43935927/article/details/108718739