分布式锁：基于 ZooKeeper 和 Redis 的实现

一、基于 ZooKeeper 实现分布式锁

1.1 Zookeeper的常用接口

package register;


import java.util.List;
import java.util.concurrent.CountDownLatch;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.Watcher.Event.KeeperState;
import org.apache.zookeeper.ZooDefs.Ids;
import org.apache.zookeeper.ZooKeeper;
import org.apache.zookeeper.data.Stat;

public class BaseZookeeper implements Watcher{

    public BaseZookeeper(){}

    public BaseZookeeper(String host){
        this.connectZookeeper(host);
    }

    private ZooKeeper zookeeper;

    //超时时间
    private static final int SESSION_TIME_OUT = 2000;
    private CountDownLatch countDownLatch = new CountDownLatch(1);

    public void process(WatchedEvent event) {
        if (event.getState() == KeeperState.SyncConnected) {
            //System.out.println("Watch received event");
            countDownLatch.countDown();
        }
    }

    //连接zookeeper
    protected void connectZookeeper(String host){
        try {
            zookeeper = new ZooKeeper(host, SESSION_TIME_OUT, this);
            countDownLatch.await();
            //System.out.println("zookeeper connection success");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    //创建节点
    protected String createNode(String path, String data){
        try {
            //永久节点
            String result = this.zookeeper.create(path, data.getBytes(), Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
            //临时节点(会话关闭就删除了,调用close后就自动删除了)
            //String result = this.zookeeper.create(path, data.getBytes(), Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL);
            System.out.println("createNode: " + result);
            return result;
        } catch (Exception e) {
            //e.printStackTrace();
            return null;
        }
    }

    //创建多级节点
    //String path = "/dubbo/server/com.wzy.server.OrderServer";
    protected boolean createMultNode(String path){

        String[] paths = path.split("/");
        String realPath = "/";
        for (int i=1; i<paths.length; i++) {
            realPath += paths[i];
            String result = createNode(realPath, "");

            if (result == null) {
                return false;
            }
            realPath += "/";
        }
        return true;
    }

    //获取路径下所有子节点
    protected List<String> getChildren(String path){
        try {
            List<String> children = zookeeper.getChildren(path, false);
            return children;
        } catch (Exception e) {
            //当路径已经是根节点(没有子节点)时,就会抛异常
            return null;
        }

    }

    //获取节点上面的数据
    protected String getData(String path) throws KeeperException, InterruptedException{
        byte[] data = zookeeper.getData(path, false, null);
        if (data == null) {
            return "";
        }
        return new String(data);
    }

    //设置节点信息
    protected Stat setData(String path, String data){
        try {
            getData(path);
            Stat stat = zookeeper.setData(path, data.getBytes(), -1);
            return stat;
        } catch (Exception e) {
            //String result = createNode(path,"");
            return null;
        }

    }

    //删除节点
    protected boolean deleteNode(String path){
        if (!path.startsWith("/")) {
            path = "/" + path;
        }
        try {
            zookeeper.delete(path, -1);
        } catch (InterruptedException e) {
            return false;
        } catch (KeeperException e) {
            return false;
        }
        return true;
    }

    //获取创建时间
    protected String getCTime(String path) throws KeeperException, InterruptedException{
        Stat stat = zookeeper.exists(path, false);
        return String.valueOf(stat.getCtime());
    }

    //获取某个路径下孩子的数量
    protected Integer getChildrenNum(String path) throws KeeperException, InterruptedException{
        int childenNum = zookeeper.getChildren(path, false).size();
        return childenNum;
    }

    //监听节点是否被删除
    protected void watchIsDel(final String path) throws Exception{
        zookeeper.exists(path, new Watcher() {
            public void process(WatchedEvent watchedEvent) {
                Event.EventType type = watchedEvent.getType();
                if (Event.EventType.NodeDeleted.equals(type)) {
                    System.out.println("结点 " + path + "被删除了");
                }
            }
        });
    }

    //关闭连接
    public void closeConnection() {
        if (zookeeper != null) {
            try {
                zookeeper.close();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

}
package register;

import framework.URL;

import java.util.List;
import java.util.Random;

/**
 * zk 的注册工具
 */
/**
 * Minimal service registry on top of ZooKeeper.
 *
 * <p>Layout: {@code /dubbo/server/<interface name>/<provider address>}; each provider address is
 * a child node of its service-interface node.
 */
public class ZkRegister extends BaseZookeeper {

    private static String ZK_HOST = "127.0.0.1:2181";
    private static final String SERVER_ADDRESS = "/dubbo/server";
    private static final String ROOT_ADDRESS = "/dubbo";
    private static final Random RANDOM = new Random();

    /** Connects to the host currently configured in {@code ZK_HOST}. */
    public ZkRegister(){
        super(ZK_HOST);
    }

    /**
     * Changes the host used by instances constructed <em>after</em> this call; the connection of
     * this (already constructed) instance is not affected.
     */
    public void setZkHost(String host){ZkRegister.ZK_HOST = host;}

    /**
     * Registers {@code url} as a provider of {@code serverInterface}.
     *
     * <p>NOTE(review): if {@code /dubbo} has readable children it deletes the whole {@code /dubbo}
     * tree first, i.e. every previously registered service and provider, not just this one.
     *
     * @return {@code true} if the provider node was created
     */
    public boolean regist(Class serverInterface, URL url){
        if (null != getChildren(ROOT_ADDRESS)){
            deleteNodeRF(ROOT_ADDRESS);
        }
        return addAddressToNode(SERVER_ADDRESS + "/" + serverInterface.getName(), new String[]{url.getAddress()});
    }

    /**
     * Picks a random registered provider address for {@code serverInterface}.
     *
     * @return one provider address, or {@code null} if the service is not registered or has no
     *         providers
     */
    public String getURLRandom(Class<?> serverInterface){
        List<String> urls = getChildren(SERVER_ADDRESS + "/" + serverInterface.getName());
        if (urls == null || urls.isEmpty()) {
            return null;
        }
        return urls.get(RANDOM.nextInt(urls.size()));
    }

    /**
     * Adds provider addresses as child nodes of {@code nodePath}, creating the path if its node
     * does not exist yet.
     *
     * <p>Example: {@code nodePath = "/dubbo/server/com.wzy.server.OrderServer"},
     * {@code address = {"192.168.37.1", "192.168.37.2", "192.168.37.3"}}.
     *
     * @param nodePath service node path; a leading {@code '/'} is added if missing
     * @param address  provider addresses, each becoming one child node
     * @return {@code false} as soon as one address node cannot be created (e.g. it already exists)
     */
    public boolean addAddressToNode (String nodePath, String[] address) {
        if (!nodePath.startsWith("/")) {
            nodePath = "/" + nodePath;
        }

        if (null == getChildren(nodePath)){
            createMultNode(nodePath);
        }
        for (String addr : address) {
            if (null == createNode(nodePath + "/" + addr, "")) {
                return false;
            }
        }
        return true;
    }

    /**
     * Recursively deletes {@code rootPath} and everything below it (children first).
     *
     * @param rootPath node path; a leading {@code '/'} is added if missing
     * @return {@code true} if {@code rootPath} itself was deleted; {@code false} if it does not
     *         exist or could not be deleted
     */
    public boolean deleteNodeRF (String rootPath) {
        if (!rootPath.startsWith("/")) {
            rootPath = "/" + rootPath;
        }
        List<String> children = getChildren(rootPath);
        if (children == null) {
            return false;
        }
        // Avoid "//child" when deleting below the root node "/".
        String prefix = rootPath.endsWith("/") ? rootPath : rootPath + "/";
        for (String child : children) {
            deleteNodeRF(prefix + child);
        }
        boolean deleted = deleteNode(rootPath);
        System.out.println("delete: " + rootPath + " " + deleted);
        return deleted;
    }
}

1.2 基于 ZooKeeper 实现分布式锁

package lock;

import org.apache.zookeeper.*;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

/**
 * Zookeeper实现分布式锁
 */
public class ZookeeperLock implements Lock {

    private ThreadLocal<ZooKeeper> zk = new ThreadLocal<ZooKeeper>();
    private String host = "localhost:2181";

    private final String LOCK_NAME = "/LOCK";
    private ThreadLocal<String> CURRENT_NODE = new ThreadLocal<String>();

    private void init() {
        if (zk.get() == null) {
            synchronized (ZookeeperLock.class) {
                if (zk.get() == null) {
                    try {
                        zk.set( new ZooKeeper(host, 2000, new Watcher() {
                            public void process(WatchedEvent watchedEvent) {
                                // do nothing..
                            }
                        }));
                    } catch (Exception e) {
                        e.printStackTrace();
                    }

                }
            }
        }
    }

    public void lock() {
        init();
        if (tryLock()) {
            System.out.println("get lock success");
        }
    }

    public boolean tryLock() {
        String node = LOCK_NAME + "/zk_";
        try {
            //创建临时顺序节点  /LOCK/zk_1
            CURRENT_NODE.set(zk.get().create(node, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL));
            //zk_1,zk_2
            List<String> list = zk.get().getChildren(LOCK_NAME, false);
            Collections.sort(list);
            System.out.println(list);
            String minNode = list.get(0);

            if ((LOCK_NAME + "/" + minNode).equals(CURRENT_NODE.get())) {
                return true;
            } else {
                //等待锁
                Integer currentIndex = list.indexOf(CURRENT_NODE.get().substring(CURRENT_NODE.get().lastIndexOf("/") + 1));
                String preNodeName = list.get(currentIndex - 1);

                //监听前一个节点删除事件
                final CountDownLatch countDownLatch = new CountDownLatch(1);
                zk.get().exists(LOCK_NAME + "/" + preNodeName, new Watcher() {
                    public void process(WatchedEvent watchedEvent) {
                        if (Event.EventType.NodeDeleted.equals(watchedEvent.getType())) {
                            countDownLatch.countDown();
                            System.out.println(Thread.currentThread().getName() + "唤醒锁..");
                        }
                    }
                });

                System.out.println(Thread.currentThread().getName() + "等待锁..");
                countDownLatch.await();//在变成0之前会一直阻塞

            }
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }

        return true;
    }

    public void unlock() {
        try {
            zk.get().delete(CURRENT_NODE.get(), -1);
            CURRENT_NODE.remove();
            zk.get().close();
        } catch (InterruptedException e) {
            e.printStackTrace();
        } catch (KeeperException e) {
            e.printStackTrace();
        }

    }

    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        return false;
    }

    public void lockInterruptibly() throws InterruptedException {

    }

    public Condition newCondition() {
        return null;
    }
}

二、基于Redis实现分布式锁

package lock;

import redis.clients.jedis.Jedis;

import java.util.Collections;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

/**
 * Redis实现分布式锁
 */
/**
 * Redis-backed distributed lock on the single key {@code LOCK}.
 *
 * <p>Acquisition is one atomic {@code SET LOCK <requestId> NX PX <expireMillis>}: the key is only
 * written if absent and always carries an expiry, so a crashed holder cannot block others forever.
 * Release runs a Lua script that deletes the key only if it still holds this instance's request
 * id, so a client never deletes a lock that another client has since acquired.
 *
 * <p>Caveats: the lock silently expires after {@code expireMillis} even if the holder is still
 * working, and it is not reentrant. Each thread uses its own {@link Jedis} connection to
 * {@code localhost}, created on demand and closed by {@link #unlock()}.
 */
public class RedisLock implements Lock {

    // Default lease time, same as the previously hard-coded value.
    private static final long DEFAULT_EXPIRE_MILLIS = 1000L;
    // Pause between acquisition attempts while waiting for another holder.
    private static final long RETRY_INTERVAL_MILLIS = 10L;
    private static final String LOCK_NAME = "LOCK";
    // Compare-and-delete executed atomically on the Redis server.
    private static final String RELEASE_SCRIPT =
            "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";

    // Per-thread connection; created lazily by connection(), closed and cleared by unlock().
    final ThreadLocal<Jedis> jedis = new ThreadLocal<Jedis>();

    // Value stored under the lock key; identifies this instance as owner on release.
    private final String requestId;
    private final long expireMillis;

    /**
     * @param requestId unique owner id stored as the lock value; must not be {@code null}
     */
    public RedisLock (String requestId) {
        this(requestId, DEFAULT_EXPIRE_MILLIS);
    }

    /**
     * @param requestId    unique owner id stored as the lock value; must not be {@code null}
     * @param expireMillis lease time after which Redis drops the lock automatically; must be > 0
     */
    public RedisLock(String requestId, long expireMillis) {
        if (requestId == null) {
            throw new NullPointerException("requestId");
        }
        if (expireMillis <= 0) {
            throw new IllegalArgumentException("expireMillis must be positive: " + expireMillis);
        }
        this.requestId = requestId;
        this.expireMillis = expireMillis;
    }

    /** Blocks (polling) until acquired; interrupts are re-asserted once the lock is held. */
    @Override
    public void lock() {
        boolean interrupted = false;
        try {
            while (!tryLock()) {
                try {
                    TimeUnit.MILLISECONDS.sleep(RETRY_INTERVAL_MILLIS);
                } catch (InterruptedException e) {
                    // lock() is uninterruptible per the Lock contract: remember and keep trying.
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Makes a single, non-blocking acquisition attempt.
     *
     * @return {@code true} if the key was absent and is now owned by this instance
     */
    @Override
    public boolean tryLock() {
        // Atomic SET NX PX replies "OK" on success and null when the key already exists.
        String reply = connection().set(LOCK_NAME, requestId, "NX", "PX", expireMillis);
        return "OK".equals(reply);
    }

    /** Polls until acquired or the timeout elapses; a non-positive time makes a single attempt. */
    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        long deadline = System.nanoTime() + unit.toNanos(time);
        while (!tryLock()) {
            long remaining = deadline - System.nanoTime();
            if (remaining <= 0) {
                return false;
            }
            TimeUnit.NANOSECONDS.sleep(Math.min(remaining, TimeUnit.MILLISECONDS.toNanos(RETRY_INTERVAL_MILLIS)));
        }
        return true;
    }

    /** Polls until acquired; aborts with {@link InterruptedException} if interrupted. */
    @Override
    public void lockInterruptibly() throws InterruptedException {
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
        while (!tryLock()) {
            TimeUnit.MILLISECONDS.sleep(RETRY_INTERVAL_MILLIS);
        }
    }

    /**
     * Deletes the lock key only if it still carries this instance's request id, then closes this
     * thread's connection.
     */
    @Override
    public void unlock() {
        Jedis client = connection();
        try {
            client.eval(RELEASE_SCRIPT, Collections.singletonList(LOCK_NAME), Collections.singletonList(requestId));
        } finally {
            client.close();
            jedis.remove();
        }
    }

    /** Conditions are not supported by this lock. */
    @Override
    public Condition newCondition() {
        throw new UnsupportedOperationException("RedisLock does not support conditions");
    }

    /** Returns this thread's connection, creating it on first use (or after unlock()). */
    private Jedis connection() {
        Jedis client = jedis.get();
        if (client == null) {
            client = new Jedis("localhost");
            jedis.set(client);
        }
        return client;
    }
}

猜你喜欢

转载自：www.cnblogs.com/wwzyy/p/10769207.html