CAS principle and ABA problem analysis of java concurrency cornerstone

This article has participated in the "Newcomer Creation Ceremony" event to start the road of gold creation together.

1. Introduce CAS with a small case

主要内容:
1. 从网站计数器实现中一步步引出CAS操作
2. 介绍JAVA中的CAS及CAS可能存在的问题
复制代码

First implement a small demo to understand what CAS is. Requirements: When we develop a website, we need to count the number of visits. Every time a user sends a request, the number of visits is +1. How to achieve this? We simulate that there are 100 people visiting at the same time, and each person sends 10 requests to the website, and finally the total number of visits should be 100 times.

package CAS;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 网站访客统计Demo
 */
public class Demo01 {
    //网站总访问量
    volatile static int count = 0;			//加了volatile保证count变量对于所有线程来说是可见的
    public static  void request() throws InterruptedException {
        //耗时5毫秒
        TimeUnit.MILLISECONDS.sleep(5);
        count++;            //访问量++,这里count++并不是原子操作
    }
    public static void main(String[] args) throws InterruptedException {
        //开始时间
        long startTime = System.currentTimeMillis();
        //最大线程数,模拟100个线程同时访问
        int threadSize = 100;

        CountDownLatch countDownLatch = new CountDownLatch(threadSize);

        for(int i = 0; i < threadSize; i++){
            new Thread(() -> {
                    try {
                        for(int j = 0; j < 10; j++){
                            request();
                        }
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        countDownLatch.countDown();
                    }
            }).start();
        }
        countDownLatch.await();
        // 100个线程询问时间
        //为了保证下面的语句在所有线程执行完之后再执行,我们使用CountDownLatch来控制
        long endTime = System.currentTimeMillis();
        System.out.println(Thread.currentThread().getName() + ", 耗时:" + (endTime - startTime) + ", count:" + count);
    }
}
复制代码

The result is as follows:

在这里插入图片描述

We can see that the result is not correct, so why?

那是因为count++并不是一个原子操作,它其实可以分为3步
1. 获取count的值 记作A   A=count
2. 将A的值加1,得到B     B=A+1 
3. 将B的值赋给count


结果不正确的原因是假如两个线程同时执行到上面步骤的第一步,那么这两个线程执行完后count的值只加了1,
但其实应该加2,结果不正确。


如何解决呢?
我们可以让线程在执行count++时串行执行,也就是排队执行,在一个线程在执行count++操作时,
其他线程必须排队等待该线程执行完后才可以执行count++操作。
那么如何实现排队效果呢?
我们很容易就可以想到,可以使用synchronized关键字或者ReentrantLock锁来实现。
复制代码

Use synchronized to solve concurrency problems.

package CAS;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 网站访客统计Demo
 */
public class Demo02 {
    //网站总访问量
    volatile static int count = 0;				//加了volatile保证count变量对于所有线程来说是可见的
    public static  synchronized void request() throws InterruptedException {
        //耗时5毫秒
        TimeUnit.MILLISECONDS.sleep(5);
        count++;            //访问量++,这里count++并不是原子操作
    }
    public static void main(String[] args) throws InterruptedException {
        //开始时间
        long startTime = System.currentTimeMillis();
        //最大线程数,模拟100个线程同时访问
        int threadSize = 100;

        //设置初值为100,表示100个线程执行完后,countDownLatch.await()处的线程才可以继续执行
        CountDownLatch countDownLatch = new CountDownLatch(threadSize);

        for(int i = 0; i < threadSize; i++){
            new Thread(() -> {
                    try {
                        for(int j = 0; j < 10; j++){
                            request();
                        }
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        countDownLatch.countDown();             //countDownLatch--
                    }
            }).start();
        }
        countDownLatch.await();
        // 100个线程询问时间
        long endTime = System.currentTimeMillis();
        System.out.println(Thread.currentThread().getName() + ", 耗时:" + (endTime - startTime) + ", count:" + count);
    }
}
复制代码

在这里插入图片描述

We can find that the result is correct, but the time-consuming is greatly increased compared to not adding synchronization (because we lock the method ), so can we optimize the time while ensuring the correct result? ?

How to optimize

As we know above, the count++ operation is divided into 3 steps

  1. Get the value of count, denoted as A, A = count
  2. Add 1 to the value of A and denote it as B, B = A + 1
  3. Assign the value of B to count

Can we only lock the third step , so that our time is not optimized?

package CAS;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class Demo03 {
    // 网站总访问量:volatile保证线程可见性,便于在下面逻辑中 -> 保证多线程之间每次获取到的count是最新值
    volatile static int count = 0;

    // 模拟访问的方法
    public static void request() throws InterruptedException {
        // 模拟耗时5毫秒
        TimeUnit.MILLISECONDS.sleep(5);

        //count ++;

        int expectCount; // 表示期望值
        // 比较并交换
        while (!compareAndSwap((expectCount = getCount()), expectCount + 1)) {
        }
    }

    /**
     * 比较并交换
     *
     * @param expectCount 期望值count
     * @param newCount    需要给count赋值的新值
     * @return 成功返回 true 失败返回false
     */
    public static synchronized boolean compareAndSwap(int expectCount, int newCount) {
        // 判断count当前值是否和期望值expectCount一致,如果一致 将newCount赋值给count
        if (getCount() == expectCount) {
            count = newCount;
            return true;
        }
        return false;
    }

    public static int getCount() {
        return count;
    }

    public static void main(String[] args) throws InterruptedException {
        // 开始时间
        long startTime = System.currentTimeMillis();
        int threadSize = 100;
        CountDownLatch countDownLatch = new CountDownLatch(threadSize);

        for (int i = 0; i < threadSize; i++) {

            Thread thread = new Thread(new Runnable() {
                @Override
                public void run() {
                    // 模拟用户行为,每个用户访问10次网站
                    try {
                        for (int j = 0; j < 10; j++) {
                            request();
                        }
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        countDownLatch.countDown();
                    }
                }
            });

            thread.start();
        }
        // 保证100个线程 结束之后,再执行后面代码
        countDownLatch.await();
        long endTime = System.currentTimeMillis();

        System.out.println(Thread.currentThread().getName() + ",耗时:" + (endTime - startTime) + ", count = " + count);
    }
}
复制代码

在这里插入图片描述

We can find that not only the results are correct , but also the time-consuming is very small , which has played a great role in optimization.

The comparison and exchange we used above, and the thread-safe way is CAS

2. CAS implementation principle

CASE

  • The full name of CAS is "CompareAndSwap", which is translated from Chinese as "compare and replace".

Definition :

  • CAS操作包含三个操作数——内存位置(V)、期望值(A)和新值(B)。
  • 如果内存位置中的数和期望值相同,处理器就将内存位置处的数更新为B,否则,不做任何操作。
  • 无论哪种情况,它都会在CAS指令之前返回当前位置上的数。(CAS在一些特殊情况下仅返回CAS是否
  • 成功,而不提取当前值)CAS有效的说明了“我认为当前值V应该包含A,如果包含该值,则将B放到该
  • 位置上,否则,不要更改该位置的值,只需要告诉我该位置上的数是几即可。”

2.1 怎么使用JDK支持的CAS操作

java中提供了对CAS操作的支持,具体在sum.misc.unsafe类中,声明如下:


//下面这三个方法分别为要修改的对象的属性为Object、Int、Long类型时使用的方法
public final native boolean compareAndSwapObject(Object var1, long var2, Object var4, Object var5);

public final native boolean compareAndSwapInt(Object var1, long var2, int var4, int var5);

public final native boolean compareAndSwapLong(Object var1, long var2, long var4, long var6);
复制代码
  • 参数var1:表示要操作的对象

  • 参数var2:表示要在操作对象中属性的偏移量 //我们知道对象在堆中有一个地址,对象的属性在堆中也有一个地址,这个偏移量就是

    ​ 要操作的对象的属性相较于对象的偏移量。

  • 参数var4:表示要修改的数据(对象的属性)的期望的值

  • 参数var5:表示需要修改为的新值

2.2 CAS实现的原理

# CAS实现的原理是什么

CAS通过调用JNI的代码实现,JNI:java Native Interface,允许java调用
其他语言。而compareAndSwapxxx系列的方法就是借助“C语言”来调用cpu底层的
指令实现的。
以常用的Intel x86来说,最终映射到的CPU的指令为“cmpxchg”,这是一个原子指令,
cpu执行此命令时,实现比较并替换的操作。

# 现代计算机动不动就上百核心,cmpxchg怎么保证多线程下的线程安全

系统底层进行CAS操作的时候,会判断当前系统是否为多核心系统,如果是,
就给总线加锁,只有一个线程会对总线加锁成功,加锁成功之后会执行CAS操作,
也就是说CAS的原子性是平台级别的(同一时刻只能有一个线程执行CAS操作)。

复制代码

3. ABA问题

什么是ABA问题

CAS需要在操作值的时候检查下值有没有发生变化,如果没有发生变化就更新,但是如果一个值原来是A,在CAS操作之前,被其他线程修改为了B,然后又修改回了A,那么CAS方法执行检查的时候会发现它的值没有发生变化,但是实际却变化了,这就是CAS的ABA问题。

模拟ABA问题

//这是AtomicInteger原子类的交换并比较源码,我们可以看出它是调用unsafe类中的方法,说明是调用底层,是线程安全的
public final boolean compareAndSet(int expect, int update) {
    return unsafe.compareAndSwapInt(this, valueOffset, expect, update);
}
复制代码
package CAS;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;

public class Demo04 {
    static AtomicInteger a = new AtomicInteger(1);

    public static void main(String[] args) {
        Thread main = new Thread(() -> {
            System.out.println("操作线程:" + Thread.currentThread().getName() + ", 初始值:" + a.get());
            try {
                int expectCount = a.get();
                int newCount = a.get() + 1;
                Thread.sleep(1000);         //main线程沉睡1秒,让出cpu
                boolean isCASSuccess = a.compareAndSet(expectCount, newCount);
                System.out.println("操作线程:" + Thread.currentThread().getName() + ", CAS操作:" + isCASSuccess);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }, "main");

        Thread other = new Thread(() -> {
            try {
                Thread.sleep(20);       //确保thread-main线程优先执行

                a.incrementAndGet();        //a+1, a = 2;
                System.out.println("操作线程:" + Thread.currentThread().getName() + ", 【increment】,值=" + a.get());

                a.decrementAndGet();        //a-1, a = 1;
                System.out.println("操作线程:" + Thread.currentThread().getName() + ", 【decrement】,值=" + a.get());

            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }, "干扰线程");
        main.start();
        other.start();
    }
}
复制代码

在这里插入图片描述

我们可以看出,当我们使用AtomicInteger原子类的ABA操作时,并不能够解决ABA问题,那么我们如何解决ABA问题呢?

# 如何解决ABA问题
解决ABA最简单的方案就是给值加一个修改版本号,每次值变化,都还修改它的版本号,
CAS操作时都去对比此版本号。


# java中ABA解决方法(AtomicStampedReference)
AtomicStampedReference主要包含一个对象引用及一个可以自动更新的整数“stamp”的pair
来解决ABA问题。
复制代码

我们看一下AtomicStampedReference类的部分源码

public class AtomicStampedReference<V> {

    private static class Pair<T> {
        final T reference;
        final int stamp;
        //我们可以看出AtomicStampedReference存到的是一个pair,一个存的是对象的引用,一个是序列版本号
        private Pair(T reference, int stamp) {
            this.reference = reference;					//对象的引用		
            this.stamp = stamp;							//版本号
        }
        static <T> Pair<T> of(T reference, int stamp) {
            return new Pair<T>(reference, stamp);
        }
    }

    private volatile Pair<V> pair;
    public boolean compareAndSet(V   expectedReference,			//期望引用
                             V   newReference,				//新值引用
                             int expectedStamp,				//期望引用的版本号
                             int newStamp) {				//新值引用的版本号
    Pair<V> current = pair;
    return
        expectedReference == current.reference &&			//期望引用与当前引用一致
        expectedStamp == current.stamp &&					//期望版本与当前版本一致
        ((newReference == current.reference &&				//新值引用等于当前引用
          newStamp == current.stamp) ||						//新值版本等于当前版本
         //新值引用等于当前引用&&新值版本等于当前版本就无需创建新的pair否则创建新的pair
         casPair(current, Pair.of(newReference, newStamp)));
	}
    ······
}
复制代码

下面我们利用AtomicStampedReference类解决ABA问题

package CAS;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicStampedReference;
import java.util.concurrent.locks.ReentrantLock;

public class Demo05 {
    static AtomicStampedReference<Integer> a = new AtomicStampedReference(new Integer(1), 1);

    public static void main(String[] args) {
        Thread main = new Thread(() -> {
            System.out.println("操作线程:" + Thread.currentThread().getName() + ", 初始值:" + a.getReference());
            try {
                Integer expectReference = a.getReference();
                Integer newReference = expectReference + 1;
                Integer expectStamp = a.getStamp();
                Integer newStamp = expectStamp + 1;
                Thread.sleep(1000);         //main线程沉睡1秒,让出cpu
                boolean isCASSuccess = a.compareAndSet(expectReference, newReference, expectStamp, newStamp);
                System.out.println("操作线程:" + Thread.currentThread().getName() + ", CAS操作:" + isCASSuccess);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }, "main");

        Thread other = new Thread(() -> {
            try {
                Thread.sleep(20);       //确保thread-main线程优先执行

                a.compareAndSet(a.getReference(), a.getReference() + 1, a.getStamp(), a.getStamp() + 1);
                System.out.println("操作线程:" + Thread.currentThread().getName() + ", 【increment】,值=" + a.getReference());

                a.compareAndSet(a.getReference(), a.getReference() - 1, a.getStamp(), a.getStamp() + 1);
                System.out.println("操作线程:" + Thread.currentThread().getName() + ", 【decrement】,值=" + a.getReference());

            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }, "干扰线程");
        main.start();
        other.start();
    }
}
复制代码

在这里插入图片描述

我们可以看到,AtomicStampedReference类解决了ABA问题

Guess you like

Origin juejin.im/post/7084713323802918926