使用easyopen拦截器防止表单重复提交

关于easyopen,请前往:码云了解。

在接口开发过程中,表单重复提交的情况会经常出现。比如做手机app开发,app端可能会连续触发两次请求,如果服务端不做处理,可能会有2次重复操作。

解决的方法也有多种:

第一种是使用token来解决,具体思路是:当用户访问视图时,由服务端生成一个Token放入session中,同时这个token跟随返回到视图页面,用js接收或者 hidden 放入要提交的表单中,当提交表单的时候 比较两个Token的值是否一致,再进行数据操作,并且再次改变Token中的值,当表单再次提交时 token中的值不一致,则不会执行相应方法了。

第二种可以用锁来处理,当用户请求进来后,对这个用户进行加锁处理,然后处理业务逻辑,只要业务逻辑没有处理完毕,该用户的其它线程请求进来始终被拒绝。

本文使用easyopen拦截器来实现第二种方式。

easyopen的拦截器使用方式同springmvc拦截器,其完整接口定义如下:

/**
 * 拦截器,原理同springmvc拦截器
 * @author tanghc
 *
 */
public interface ApiInterceptor {
    /**
     * 预处理回调方法,在方法调用前执行
     * @param request
     * @param response
     * @param serviceObj service类
     * @param argu 方法参数
     * @return
     * @throws Exception
     */
    boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu)
            throws Exception;

    /**
     * 接口方法执行完后调用此方法。
     * @param request
     * @param response
     * @param serviceObj service类
     * @param argu 参数
     * @param result 方法返回结果
     * @throws Exception
     */
    void postHandle(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu,
            Object result) throws Exception;

    /**
     * 结果包装完成后执行
     * @param request
     * @param response
     * @param serviceObj service类
     * @param argu 参数
     * @param result 最终结果,被包装过
     * @param e 
     * @throws Exception
     */
    void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu,
            Object result, Exception e) throws Exception;

    /**
     * 匹配拦截器
     * @param apiMeta 接口信息
     * @return
     */
    boolean match(ApiMeta apiMeta);
}

本次将要实现的需求如下:

  1. 当一个用户线程正在处理一个业务方法时,该用户的其它线程进来被拒绝
  2. 支持集群环境处理(单机程序可用synchronize解决,但不适合集群)

实现思路

  1. 使用redis做全局锁,在preHandle方法中申请锁
  2. afterCompletion方法中释放锁

代码实现

首先给出redis申明锁/释放锁工具类:

/**
<pre>
redis分布式锁
https://wudashan.cn/2017/10/23/Redis-Distributed-Lock-Implement/
</pre>
 * @author tanghc
 *
 */
public class RedisTool {

    private static final String LOCK_SUCCESS = "OK";
    private static final String SET_IF_NOT_EXIST = "NX";
    private static final String SET_WITH_EXPIRE_TIME = "PX";
    
    private static final Long RELEASE_SUCCESS = 1L;

    /**
     * 尝试获取分布式锁
     * @param jedis Redis客户端
     * @param lockKey 锁
     * @param requestId 请求标识
     * @param expireTimeMilliseconds 超期时间,多少毫秒后这把锁自动释放
     * @return 是否获取成功
     */
    public static boolean tryGetDistributedLock(Jedis jedis, String lockKey, String requestId, int expireTimeMilliseconds ) {
		String result = jedis.set(lockKey, requestId, SET_IF_NOT_EXIST, SET_WITH_EXPIRE_TIME, expireTimeMilliseconds);
		
		if (LOCK_SUCCESS.equals(result)) {
			return true;
		}
		return false;

    }

    

    /**
     * 释放分布式锁
     * @param jedis Redis客户端
     * @param lockKey 锁
     * @param requestId 请求标识
     * @return 是否释放成功
     */
    public static boolean releaseDistributedLock(Jedis jedis, String lockKey, String requestId) {
		String script = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";
		Object result = jedis.eval(script, Collections.singletonList(lockKey), Collections.singletonList(requestId));
		
		if (RELEASE_SUCCESS.equals(result)) {
			return true;
		}
		return false;
    }
}

工具类的实现原理参考:Redis分布式锁的正确实现方式(Java版)

然后需要一个管理Jedis对象的工具类

@Component
public class JedisConfig {

	@Value("${spring.redis.database}")
	private String database;
	@Value("${spring.redis.host}")
	private String host;
	@Value("${spring.redis.password}")
	private String password;
	@Value("${spring.redis.port}")
	private String port;
	@Value("${spring.redis.timeout}")
	private String timeout;
	
	@Value("${spring.redis.pool.max-idle}")
	private String maxIdle;
	@Value("${spring.redis.pool.min-idle}")
	private String minIdle;
	@Value("${spring.redis.pool.max-active}")
	private String maxActive;
	@Value("${spring.redis.pool.max-wait}")
	private String maxWait;

	@Bean
	public JedisPool jedisPool() {
		JedisPoolConfig jedisPoolConfig = new JedisPoolConfig();
		jedisPoolConfig.setMaxIdle(NumberUtils.toInt(maxIdle,JedisPoolConfig.DEFAULT_MAX_IDLE));
		jedisPoolConfig.setMinIdle(NumberUtils.toInt(minIdle,JedisPoolConfig.DEFAULT_MIN_IDLE));
		jedisPoolConfig.setMaxTotal(NumberUtils.toInt(maxActive,JedisPoolConfig.DEFAULT_MAX_TOTAL));
		jedisPoolConfig.setMaxWaitMillis(NumberUtils.toLong(maxWait, JedisPoolConfig.DEFAULT_MAX_WAIT_MILLIS));
		jedisPoolConfig.setTestOnBorrow(true);
		jedisPoolConfig.setTestOnReturn(true);
		
		return new JedisPool(jedisPoolConfig, 
				host, 
				NumberUtils.toInt(port, 6379),  
				NumberUtils.toInt(timeout, 3000),
				password,
				NumberUtils.toInt(database, 0));
	}

}

这里使用spring依赖注入一个JedisPool对象。

最后是编写拦截器,首先拦截器的伪代码如下:

/**
<pre>
业务处理锁(防暴击):
同一个人同一时间只能处理一个业务。
</pre>
 * @author tanghc
 *
 */
public class LockInterceptor extends ApiInterceptorAdapter {

	private Logger logger = LoggerFactory.getLogger(getClass());

	// 拦截接口名当中有这些关键字的
    private static List<String> uriKeyList = Arrays.asList("order.cancel", "order.create");
    
    private int lockExpireMilliseconds = 3000; // 锁过期时间,3秒
    
	private JedisPool jedisPool;

	public LockInterceptor() {
                // 从spring容器中获取对象
		jedisPool = SpringContextUtils.getBean(JedisPool.class);
	}

	@Override
	public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu)
			throws Exception {
		Jedis jedis = jedisPool.getResource();		
		try {
			boolean hasLock = 申请redis锁
			if(!hasLock) { // 如果没有得到锁,说明重复提交
				response返回错误信息
				return false;
			}
		}finally {
			jedis.close(); // 最后别忘了关闭锁
		}
		
		return true;
	}

	@Override
	public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object serviceObj,
			Object argu, Object result, Exception e) throws Exception {			
		
		// 释放锁
                Jedis jedis = jedisPool.getResource();
		try {
			RedisTool.releaseDistributedLock(jedis, lockKey, requestId);
		} finally {
			jedis.close();
		}
	}
	
		
    @Override
    public boolean match(ApiMeta apiMeta) {
        String name = apiMeta.getName();
        return uriKeyList.contains(name); // 匹配接口,匹配到才执行该拦截器
    }

}

完整代码

理解了伪代码逻辑后,再来看下完整代码

/**
<pre>
业务处理锁(防暴击):
同一个人同一时间只能处理一个业务。
</pre>
 * @author tanghc
 *
 */
public class LockInterceptor extends ApiInterceptorAdapter {

	private Logger logger = LoggerFactory.getLogger(getClass());

	// 拦截接口名当中有这些关键字的
    private static List<String> uriKeyList = Arrays.asList("order.cancel", "order.create");
    
    private int lockExpireMilliseconds = 3000; // 锁过期时间,3秒
    
	private JedisPool jedisPool;

	public LockInterceptor() {
		jedisPool = SpringContextUtils.getBean(JedisPool.class);
	}

	@Override
	public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu)
			throws Exception {
		
		LoginUser loginUser = ApiUtil.getCurrentUser(); // 获取当前登录用户
		
		String lockKey = this.getLockKey(loginUser);
		String requestId = this.getRequestId(loginUser);
		Jedis jedis = jedisPool.getResource();
		
		try {
			boolean hasLock = RedisTool.tryGetDistributedLock(jedis, lockKey, requestId , lockExpireMilliseconds);
			// 如果没有取到锁,认为是暴击,直接返回
			if(!hasLock) {
				logger.warn("用户({},{})访问{}产生暴击!",loginUser.getId(),loginUser.getPhone(),ApiContext.getApiParam().fatchNameVersion());
				ApiResult result = new ApiResult();
				result.setCode(-102);
				result.setMsg("您已提交,请耐心等待哦");
				ResponseUtil.renderJson(response, JSON.toJSONString(result));
				return false;
			}
		}finally {
			jedis.close();
		}
		
		return true;
	}

	@Override
	public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object serviceObj,
			Object argu, Object result, Exception e) throws Exception {	
		
		LoginUser loginUser = ApiUtil.getCurrentUser();
		
		String lockKey = this.getLockKey(loginUser);
		String requestId = this.getRequestId(loginUser);
		Jedis jedis = jedisPool.getResource();
		try {
			RedisTool.releaseDistributedLock(jedis, lockKey, requestId);
		} finally {
			jedis.close();
		}
	}
	
	private String getLockKey(LoginUser loginUser) {
		return "api_lock_key:" + String.valueOf(loginUser.getId());
	}
	
	private String getRequestId(LoginUser loginUser) {
		return "api_lock_request_id_" + loginUser.getId();
	}
	
    @Override
    public boolean match(ApiMeta apiMeta) {
        String name = apiMeta.getName();
        return uriKeyList.contains(name);
    }

}

单元测试

最后给出单元测试代码

public class ApiTest extends TestCase {

	String url = "http://localhost:8080/api";
	String appId = "test";
	String secret = "123456";
	String token = "0094FC708C34490F949A9FAB90453195";

	/**
	 * 暴击测试,10条线程同时请求
	 * @throws InterruptedException
	 */
	@Test
	public void testLock() throws InterruptedException {
		int threadsCount = 10;
		final CountDownLatch countDownLatch = new CountDownLatch(1);
		final CountDownLatch count = new CountDownLatch(threadsCount);
		AtomicInteger successCount = new AtomicInteger();
		
		for (int i = 0; i < threadsCount; i++) {
			new Thread(new Runnable() {
				@Override
				public void run() {
					try {
						countDownLatch.await(); // 等在这里,执行countDownLatch.countDown();集体触发
						// 请求接口
						Map<String, Object> busiParam = new HashMap<>();
						String resp = doPost("order.create", busiParam);
						if("0".equals(JSON.parseObject(resp).getString("code"))) {
							successCount.incrementAndGet();
						}
						System.out.println(resp);
					} catch (Exception e) {
						e.printStackTrace();
					}finally {
						count.countDown();
					}
				}
			}).start();
		}
		countDownLatch.countDown();
		count.await();
		
		System.out.println("成功条数:" + successCount.get());
	}

	private String doPost(String name, Map<String, Object> busiParam) throws IOException {
		Map<String, String> param = new HashMap<String, String>();

		String json = JSON.toJSONString(busiParam);

		param.put(ParamNames.API_NAME, name);
		param.put(ParamNames.APP_KEY_NAME, appId);
		param.put(ParamNames.DATA_NAME, URLEncoder.encode(json, "UTF-8"));
		param.put(ParamNames.TIMESTAMP_NAME, getTime());
		param.put(ParamNames.VERSION_NAME, "");
		param.put(ParamNames.FORMAT_NAME, "json");
		param.put(ParamNames.ACCESS_TOKEN_NAME, token);

		String sign = buildSign(param, secret);

		param.put(ParamNames.SIGN_NAME, sign);

		System.out.println("请求内容:" + JSON.toJSONString(param));

		String resp = HttpUtil.post(url, param);

		return resp;
	}

	/**
	 * 构建签名
	 * 
	 * @param paramsMap
	 *            参数
	 * @param secret
	 *            密钥
	 * @return
	 * @throws IOException
	 */
	public static String buildSign(Map<String, ?> paramsMap, String secret) throws IOException {
		Set<String> keySet = paramsMap.keySet();
		List<String> paramNames = new ArrayList<String>(keySet);

		Collections.sort(paramNames);

		StringBuilder paramNameValue = new StringBuilder();

		for (String paramName : paramNames) {
			paramNameValue.append(paramName).append(paramsMap.get(paramName));
		}

		String source = secret + paramNameValue.toString() + secret;

		return md5(source);
	}

	/**
	 * 生成md5,全部大写
	 * 
	 * @param message
	 * @return
	 */
	public static String md5(String message) {
		try {
			// 1 创建一个提供信息摘要算法的对象,初始化为md5算法对象
			MessageDigest md = MessageDigest.getInstance("MD5");

			// 2 将消息变成byte数组
			byte[] input = message.getBytes();

			// 3 计算后获得字节数组,这就是那128位了
			byte[] buff = md.digest(input);

			// 4 把数组每一字节(一个字节占八位)换成16进制连成md5字符串
			return byte2hex(buff);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * 二进制转十六进制字符串
	 * 
	 * @param bytes
	 * @return
	 */
	private static String byte2hex(byte[] bytes) {
		StringBuilder sign = new StringBuilder();
		for (int i = 0; i < bytes.length; i++) {
			String hex = Integer.toHexString(bytes[i] & 0xFF);
			if (hex.length() == 1) {
				sign.append("0");
			}
			sign.append(hex.toUpperCase());
		}
		return sign.toString();
	}

	public String getTime() {
		return new SimpleDateFormat(ParamNames.TIMESTAMP_PATTERN).format(new Date());
	}

}

如果您有其它好的方法和建议,欢迎在评论中讨论。

猜你喜欢

转载自my.oschina.net/u/3658366/blog/1800004
今日推荐