利用cas+单向队列实现一个可重入的轻量级的非公平的排他锁,参考部分ReentrantLock源码

实现思路,定义一个state表示锁的状态,state=0,表示锁可以被获取,state>0,表示锁正在被当前线程或者其他线程占用,获取锁的方法,采用cas将state设置为1,如果锁被当前线程占用,再次获取,只需要++state,如果获取不到锁,就将当前线程加入到等待队列中,在释放锁的时候,如果state=0,就通知等待队列中第一个等待的线程取获取锁。具体代码如下:

import sun.misc.Unsafe;
import java.lang.reflect.Field;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;

/**
 * 仿照ReentrantLock原理手动实现一个轻量级排他锁,为了简单起见,这里只是实现lock与unlock方法
 * 关键词:重入锁,排他锁,轻量级锁,非公平性锁,独占锁
 */
public class MyReentrantLock implements Lock {

    /**
     * 同步器
     */
    public final Sync sync;

    /**
     * 构造方法
     */
    public MyReentrantLock() {
        this.sync = new Sync();
    }

    /**
     * 加锁方法
     */
    public void lock() {
        sync.lock(1);
    }

    /**
     * 释放锁的方法
     */
    public void unlock() {
        sync.unlock();
    }

    /**
     * 同步工具类作为内部辅助类
     */
    static class Sync extends AbstractQueuedSynchronizer {
        /**
         * 锁的状态:
         * 0:表示锁没有被其他线程占用
         * >0:表示锁已经被当前线程或者其他线程占用
         * 当前线程占用锁,调用lock方法,表示锁的重入 state+1
         * 释放锁的时候ctate-1
         * 初始值为0
         */
        public volatile int state = 0;
        /**
         * state的内存偏移地址
         */
        private static final long stateOffset;
        /**
         * 引入UNSAFE对象,目的是使用其cas方法
         */
        private static final sun.misc.Unsafe UNSAFE;
        /**
         * 当前占用锁的线程
         */
        public volatile Thread ownerThread = null;
        /**
         * 等待队列头节点
         */
        public volatile Node head;
        /**
         * 等待队列尾节点
         */
        public volatile Node tail;
        /**
         * head节点的内存偏移地址
         */
        private static final long headOffset;
        /**
         * tail节点的内存偏移地址tail
         */
        private static final long tailOffset;

        public Sync() {
            head = new Node(null);
            tail = new Node(null);
            head.next=tail;
        }

        static {
            try {
                Field f = Unsafe.class.getDeclaredField("theUnsafe");
                f.setAccessible(true);
                UNSAFE = (Unsafe) f.get(null);
                Class<?> k = Sync.class;
                stateOffset = UNSAFE.objectFieldOffset
                        (k.getDeclaredField("state"));
                headOffset = UNSAFE.objectFieldOffset
                        (k.getDeclaredField("head"));
                tailOffset = UNSAFE.objectFieldOffset
                        (k.getDeclaredField("tail"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }

        /**
         * 尝试获取锁
         * 获取成功,返回true;
         * 获取失败,返回false;
         *
         * @param arg
         * @return
         */
        @Override
        protected boolean tryAcquire(int arg) {
            //获取当前线程
            Thread thread = Thread.currentThread();
            //state为0,表示锁处于空闲状态,可以被获取,利用cas获取成功后,直接返回true;
            if (state == 0) {
                if (compareAndSetStateValue(0, 1)) {
                    ownerThread = thread;
                    return true;
                }
            }
            //当前线程占用该锁,此时只需要将state状态值加1即可,表示锁的重入
            else if (thread == ownerThread) {
                //注意,此步操作不需要cas的原因是,锁被当前线程占用,state不会被其他线程修改,故不存在线程安全性问题
                ++state;
                return true;
            }
            //其他线程占用锁,直接返回false
            return false;
        }

        /**
         * 只有内存中的值与预期值i相同的时候,才会将内存中的值更新为arg,此步操作为原子操作
         *
         * @param i   预期值
         * @param arg 需要更新的值
         * @return
         */
        private boolean compareAndSetStateValue(int i, int arg) {
            return UNSAFE.compareAndSwapInt(this, stateOffset, i, arg);
        }

        /**
         * 尝试释放锁
         *
         * @param arg
         * @return
         */
        @Override
        protected boolean tryRelease(int arg) {
            //释放锁的时候,通知同步队列中等待的线程取获取锁
            if(state==0){
                ownerThread=null;
                Node next = head.next;
                if (null != next) {
                    next.nodeState = 1;
                    LockSupport.unpark(next.thread);
                }
            }
            return true;
        }

        /**
         * 获取锁的方法
         */
        public void lock(int arg) {
            //获取锁不成功,就将当前线程加入等待队列中(添加到队尾)当前线程会等待被叫醒
            if(!tryAcquire(arg)){
                addWaitQueue();
            }
        }

        /**
         * 节点加入队列的方法
         */
        private void addWaitQueue() {
            //创建一个节点
            Node node = new Node(Thread.currentThread());
            //将当前节点添加到尾节点之后,并设置当前节点为新的尾节点
            addWaitTail(node);
            for (; ; ) {
                Node headNext = head.next;
                //如果head的下一个节点保存的是当前线程,并且当前节点状态为1,表示可以获取锁
                if (node.thread== headNext.thread && headNext.nodeState == 1) {
                    if (tryAcquire(1)) {
                        ownerThread = Thread.currentThread();
                        //设置当前节点为新的头节点
                        headNext.nodeState = 0;
                        head = headNext;
                        return;
                    }
                }
                //当前线程park
                LockSupport.park();
            }

        }

        private void addWaitTail(Node node) {
            for (; ; ) {
                Node last = tail;
                //尾节点中没有保存线程,直接将当前线程保存到尾节点中
                if(last.thread==null){
                    if(tail.casTailThread(null,Thread.currentThread())){
                        break;
                    }
                }
                //尾节点中保存有线程,将该节点添加到尾节点之后,并把该节点设置为新的尾节点
                Node next = last.next;
                if (null == next) {
                    if (tail.casNext(null, node)) {
                        tail=node;
                        break;
                    }
                }

            }
        }

        private void casTail(Node last, Node next) {
            UNSAFE.compareAndSwapObject(this, tailOffset, last, next);
        }

        public void unlock() {
            --state;
            tryRelease(1);
        }
    }

    /**
     * 内部节点类
     */
    static class Node {
        /**
         * 节点保存的线程
         */
        public Thread thread;
        /**
         * 下一个节点
         */
        public volatile Node next;
        /**
         * 节点的状态
         */
        public volatile int nodeState;
        /**
         * 引入UNSAFE对象,目的是使用其cas方法
         */
        private static final sun.misc.Unsafe UNSAFE;
        /**
         * 下一个节点的内存偏移地址
         */
        private static final long nextOffset;

        /**
         * 保存的线程的内存偏移地址
         */
        private static final long threadOffsset;

        public Node(Thread thread) {
            this.thread = thread;
        }


        static {
            try {
                Field f = Unsafe.class.getDeclaredField("theUnsafe");
                f.setAccessible(true);
                UNSAFE = (Unsafe) f.get(null);
                Class<?> k = Node.class;
                nextOffset = UNSAFE.objectFieldOffset
                        (k.getDeclaredField("next"));
                threadOffsset = UNSAFE.objectFieldOffset
                        (k.getDeclaredField("thread"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }

        public boolean casNext(Object o, Node node) {
            return UNSAFE.compareAndSwapObject(this, nextOffset, o, node);
        }

        public boolean casTailThread(Object o, Thread currentThread) {
            return UNSAFE.compareAndSwapObject(this,threadOffsset,o,currentThread);
        }
    }

    public void lockInterruptibly() throws InterruptedException {

    }

    public boolean tryLock() {
        return false;
    }

    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        return false;
    }

    public Condition newCondition() {
        return null;
    }
}

测试方法如下,开启2000个线程,每个线程循环10次往集合中添加元素,这里使用并发工具类CyclicBarrier只是为了更好的保证线程能够并发的执行,经过测试,最后集合的大小尾20000,说明该锁生效

import java.util.LinkedList;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

public class TestMyReentrantLock {
    private static LinkedList<Integer> list = new LinkedList<Integer>();
    public static void main(String[] args) {
        final CyclicBarrier barrier = new CyclicBarrier(2000);
        final MyReentrantLock lock = new MyReentrantLock();
        for (int i = 0; i<2000;i++){
            new Thread(new Runnable() {
                public void run() {
                    try {
                        //保证2000个线程并发执行
                        barrier.await();
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } catch (BrokenBarrierException e) {
                        e.printStackTrace();
                    }
                    for(int i=0;i<10;i++){
                        lock.lock();
                        list.add(i);
                        lock.unlock();
                    }

                }
            }).start();
        }
        //阻塞等待所有线程执行完毕
        while(Thread.activeCount()!=2){ }
        System.out.println("list的大小为:"+list.size());
        System.out.println(list);

    }
}
发布了2 篇原创文章 · 获赞 3 · 访问量 98

猜你喜欢

转载自blog.csdn.net/WKzhangliang123/article/details/104767015