ZooKeeper 分布式锁的简单实现。

前言:无

代码1:分布式锁的实现类

package com.wj.demo.lock;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.Watcher.Event.KeeperState;
import org.apache.zookeeper.ZooDefs;
import org.apache.zookeeper.ZooDefs.Ids;
import org.apache.zookeeper.ZooKeeper;
import org.apache.zookeeper.data.Stat;

/**
 * 自定义的分布式锁
 */
public class ZkLock implements Lock,Watcher{
	
	/**不同名称锁前缀和序号之间的分割符号 */
	private static final String splitStr = "_xxxx_";
	/**超时时间 */
	private int sessionTimeout = 30000;
	/**异常集合 */
	private List<Exception> exception = new ArrayList<Exception>();
	/**闭锁  */
	private CountDownLatch latch = null;
	/**Zk对象  */
	private ZooKeeper zk;
	/**根节点 */
	private String root = "/my_root";
	/**锁名称(节点的前缀)*/
	private String lockName;
	/**当前锁名称(全路径)*/
	private String thisNode;
	/**上一个锁的名称(只含子路径)*/
	private String preNode;
	
	public ZkLock(String addr,String lockName){
		this.lockName=lockName;
		try {
			zk = new ZooKeeper(addr, this.sessionTimeout, this);
			latch = new CountDownLatch(1);
			latch.await();
			latch = null;
			Stat stat = zk.exists(root, false);
			if(stat==null){
				 zk.create(root, new byte[0], Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
			}
		}catch(Exception e){
			exception.add(e);
		}
	}

	@Override
	public void lock() {
		if(exception.size()!=0){
			throw new LockException(exception.get(0));
		}
		try{
			if (this.tryLock()){
				System.out.println("线程 " + Thread.currentThread().getId() + ": 节点" + thisNode + " 得到锁");
				return;
			} else {
				waitForLock(preNode, sessionTimeout);// 等待锁
			}
		}catch(KeeperException e){
			throw new LockException(e);
		}catch(InterruptedException e){
			throw new LockException(e);
		}
	}
	
	/**
	 * 当前等待锁。通过上一个节点的通知notify,或超时后重新尝试得到锁!
	 */
	private boolean waitForLock(String preNode, long waitTime)throws InterruptedException, KeeperException{
		Stat stat = zk.exists(root + "/" + preNode, true);
		if (stat != null) {
			System.out.println("线程 " + Thread.currentThread().getId() + ": 等待节点 " + root + "/" + preNode);
			this.latch = new CountDownLatch(1);
			this.latch.await(waitTime, TimeUnit.MILLISECONDS); //等待锁...
			this.latch = null;
		}
		return true;
	 }

	@Override
	public boolean tryLock() {
		try {
			if (lockName.contains(splitStr))throw new LockException("lockName can not contains ["+splitStr+"]");
			thisNode = zk.create(root + "/" + lockName + splitStr, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE,
					CreateMode.EPHEMERAL_SEQUENTIAL);
			System.out.println(thisNode + " is created ");
			List<String> subNodes = zk.getChildren(root, false);
			List<String> lockObjNodes = new ArrayList<String>(); //里面装的符合的节点名称是: lockName + splitStr + 000000000??
			for (String node : subNodes) {
				String _node = node.split(splitStr)[0];
				if (_node.equals(lockName)) {
					lockObjNodes.add(node);
				}
			}
			Collections.sort(lockObjNodes);
			if (thisNode.equals(root + "/" + lockObjNodes.get(0))) {
				return true;
			}
			String temp = thisNode.substring(thisNode.lastIndexOf("/") + 1);
			preNode = lockObjNodes.get(Collections.binarySearch(lockObjNodes, temp) - 1);
		} catch (KeeperException e) {
			throw new LockException(e);
		} catch (InterruptedException e) {
			throw new LockException(e);
		}
		return false;
	}

	@Override
	public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
		try {
			if (this.tryLock()) {
				return true;
			}
			return waitForLock(preNode, sessionTimeout);
		} catch (Exception e) {
			e.printStackTrace();
		}
		return false;
	}

	@Override
	public void unlock() {
		try {
			System.out.println("unlock " + thisNode);
			zk.delete(thisNode, -1);
			thisNode = null;
			zk.close();
		} catch (InterruptedException e) {
			e.printStackTrace();
		} catch (KeeperException e) {
			e.printStackTrace();
		}
	}

	@Override
	public void lockInterruptibly() throws InterruptedException {
		this.lock();
	}


	@Override
	public Condition newCondition() {
		return null;
	}
	
	/**1、latch在等待,并且zk已经连接时,才会得到通知正式干活。
	 * 2、当前线程通过latch等待,exists监控上一个节点释放锁时,就通过zk的通知通知当前线程。 
	 ***/
	@Override
	public void process(WatchedEvent event) {
		if(this.latch!=null&&event.getState()==KeeperState.SyncConnected){
			this.latch.countDown();
		}
	}
	
	public static class LockException extends RuntimeException{
		private static final long serialVersionUID = 1L;
		public LockException(String e) {
			super(e);
		}
		public LockException(Exception e) {
			super(e);
		}
	}
}	

代码2:并发测试辅助类

package com.wj.demo.lock;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Concurrency test helper: runs every task on its own thread, releases all threads at once, waits
 * for them to finish, and prints min/max/avg task run time (ms) plus the error count.
 *
 * <p>All work happens inside the constructor, which blocks until every task has finished.
 */
public class ConcurrentTest {

	/** Wall-clock duration (ms) of each task that completed normally. */
	private CopyOnWriteArrayList<Long> list = new CopyOnWriteArrayList<Long>();

	/** Start gate: worker threads block on it until all of them have been created. */
	private CountDownLatch startSignal = new CountDownLatch(1);

	/** Finish gate: one count per task; the constructing thread waits for it to reach zero. */
	private CountDownLatch doneSignal = null;
	/** Number of tasks that were interrupted before starting or threw while running. */
	private AtomicInteger err = new AtomicInteger();
	private ConcurrentTask[] tasks = null;

	/**
	 * Runs the given tasks concurrently (one thread each) and prints timing statistics.
	 *
	 * @param tasks the tasks to run
	 * @throws IllegalArgumentException if {@code tasks} is null or empty
	 */
	public ConcurrentTest(ConcurrentTask ... tasks){
		// Validate before reading tasks.length, otherwise a null array surfaces as an NPE.
		if(tasks==null||tasks.length==0){
			throw new IllegalArgumentException("没有任务参数!");
		}
		this.doneSignal = new CountDownLatch(tasks.length);
		this.tasks = tasks;
		start();
	}

	/**
	 * Starts the threads, opens the start gate, waits for all tasks and prints the statistics.
	 */
	private void start() {
		this.createThread();
		this.startSignal.countDown();
		try {
			doneSignal.await();
		} catch (InterruptedException e) {
			// Keep the interrupt visible to the caller; the statistics cover tasks finished so far.
			Thread.currentThread().interrupt();
		}
		getExcuteTime();
	}

	/**
	 * Creates one thread per task. Each thread waits on the start gate, runs its task, records the
	 * duration, and always counts down the finish gate.
	 */
	private void createThread(){
		for (int i = 0; i < tasks.length; i++) {
			final ConcurrentTask task = tasks[i];
			new Thread(new Runnable(){
				@Override
				public void run() {
					try {
						startSignal.await();
						long start = System.currentTimeMillis();
						task.run();
						long end = System.currentTimeMillis();
						list.add(end-start); // duration of this single task.run()
					} catch (InterruptedException e) {
						err.getAndIncrement();
					} catch (RuntimeException e) {
						// A failing task is counted as an error instead of killing the thread.
						err.getAndIncrement();
						e.printStackTrace();
					} finally {
						// Always count down, or start() would wait forever on a failed task.
						doneSignal.countDown();
					}
				}
			}).start();
		}
	}

	/**
	 * Prints min/max/avg of the recorded durations and the error count. When no task completed
	 * normally, only the error count is printed.
	 */
	private void getExcuteTime() {
		List<Long> temp = new ArrayList<Long>(this.list);
		if (temp.isEmpty()) {
			System.out.println("err: " + err.get());
			return;
		}
		Collections.sort(temp);
		long min = temp.get(0);
		long max = temp.get(temp.size()-1);
		long sum = 0L;
		for (long t : temp) {
			sum += t;
		}
		long avg = sum/temp.size();
		System.out.println("min: " + min);
		System.out.println("max: " + max);
		System.out.println("avg: " + avg);
		System.out.println("err: " + err.get());
	}

	/**
	 * A unit of work executed by {@link ConcurrentTest}.
	 */
	public interface ConcurrentTask {
		void run();
	}

}

代码3:测试类

package com.wj.demo.lock;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;

import org.apache.commons.lang.StringUtils;
import org.apache.zookeeper.server.ZooKeeperServerMain;

import com.wj.demo.lock.ConcurrentTest.ConcurrentTask;

public class ZkTest {
	
	public static void main(String[] args) throws InterruptedException {
		
    	ZkServer.start(2181);
    	Thread.sleep(2000);
    	
    	//第一个任务去拿锁,干活3秒。
        Runnable task1 = new Runnable(){
            public void run() {
            	ZkLock lock = null;
                try {
                    lock = new ZkLock("127.0.0.1:2181","test");
                    lock.lock();
                    Thread.sleep(3000);
                    System.out.println("===Thread " + Thread.currentThread().getId() + " running");
                } catch (Exception e) {
                    e.printStackTrace();
                }
                finally {
                    if(lock != null)
                        lock.unlock();
                }
            }
             
        };
        new Thread(task1).start();
        
        //主线程,停顿1秒。
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e1) {
            e1.printStackTrace();
        }
        
        //60个其他任务开启线程尝试去拿到锁干活。
        ConcurrentTask[] tasks = new ConcurrentTask[60];
        for(int i=0;i<tasks.length;i++){
            ConcurrentTask tempTask = new ConcurrentTask(){
                public void run() {
                	ZkLock lock = null;
                    try {
                        lock = new ZkLock("127.0.0.1:2181","test");
                        lock.lock();
                        System.out.println("Thread " + Thread.currentThread().getId() + " running");
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    finally {
                        lock.unlock();
                    }
                }
            };
            tasks[i] = tempTask;
        }
        new ConcurrentTest(tasks); //执行60个任务在多线程环境中执行,并统计时间。
    }
	
	/**
	 * 内置方式的ZkServer
	 */
	public static class ZkServer extends ZooKeeperServerMain implements Runnable{
		
		final File confFile; // 配置文件

		public ZkServer(int clientPort) throws IOException {
			File tmpDir = createTmpDir();
			confFile = new File(tmpDir, "zoo.cfg");
			FileWriter fwriter = new FileWriter(confFile);
			fwriter.write("tickTime=2000\n");
			fwriter.write("initLimit=10\n");
			fwriter.write("syncLimit=5\n");

			File dataDir = new File(tmpDir, "data");
			if (!dataDir.mkdir()) {
				throw new IOException("unable to mkdir " + dataDir);
			}
			String df = StringUtils.replace(dataDir.toString(), "\\", "/");
			fwriter.write("dataDir=" + df + "\n");
			fwriter.write("clientPort=" + clientPort + "\n");
			fwriter.flush();
			fwriter.close();
		}

		public void run() {
			String args[] = new String[1];
	         args[0] = confFile.toString();
			try {
				initializeAndRun(args);
			} catch(Exception e) {
				e.printStackTrace();
			}
		}
		
		public static void start(int clientPort){
			try {
				new Thread(new ZkServer(clientPort)).start();
			} catch (IOException e) {
				e.printStackTrace();
			}
		}

		private static File createTmpDir() throws IOException {
			File basetemp = new File(System.getProperty("build.test.dir", "D:\\temp"));
			if (!basetemp.exists()) {
				basetemp.mkdir();
			}
			File tmpFile = File.createTempFile("test", ".junit", basetemp);
			File tmpDir = new File(tmpFile + ".dir");
			tmpDir.mkdirs();
			return tmpDir;
		}
	}
}
结语:参考了通过百度搜到的多篇博客,边借鉴边改写,简单实现了一个 ZooKeeper 分布式锁。

猜你喜欢

转载自:https://blog.csdn.net/shuixiou1/article/details/80617360