关于easyopen,请前往:码云了解。
在接口开发过程中,表单重复提交的情况会经常出现。比如做手机app开发,app端可能会连续触发两次请求,如果服务端不做处理,可能会有2次重复操作。
解决的方法也有多种:
第一种是使用token来解决,具体思路是:当用户访问视图时,由服务端生成一个Token放入session中,同时这个token跟随返回到视图页面,用js接收或者 hidden 放入要提交的表单中,当提交表单的时候 比较两个Token的值是否一致,再进行数据操作,并且再次改变Token中的值,当表单再次提交时 token中的值不一致,则不会执行相应方法了。
第二种可以用锁来处理,当用户请求进来后,对这个用户进行加锁处理,然后处理业务逻辑,只要业务逻辑没有处理完毕,该用户的其它线程请求进来始终被拒绝。
本文使用easyopen拦截器来实现第二种方式。
easyopen的拦截器使用方式同springmvc拦截器,其完整接口定义如下:
/**
* 拦截器,原理同springmvc拦截器
* @author tanghc
*
*/
public interface ApiInterceptor {
/**
* 预处理回调方法,在方法调用前执行
* @param request
* @param response
* @param serviceObj service类
* @param argu 方法参数
* @return
* @throws Exception
*/
boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu)
throws Exception;
/**
* 接口方法执行完后调用此方法。
* @param request
* @param response
* @param serviceObj service类
* @param argu 参数
* @param result 方法返回结果
* @throws Exception
*/
void postHandle(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu,
Object result) throws Exception;
/**
* 结果包装完成后执行
* @param request
* @param response
* @param serviceObj service类
* @param argu 参数
* @param result 最终结果,被包装过
* @param e
* @throws Exception
*/
void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu,
Object result, Exception e) throws Exception;
/**
* 匹配拦截器
* @param apiMeta 接口信息
* @return
*/
boolean match(ApiMeta apiMeta);
}
本次将要实现的需求如下:
- 当一个用户线程正在处理一个业务方法时,该用户的其它线程进来被拒绝
- 支持集群环境处理(单机程序可用synchronize解决,但不适合集群)
实现思路
- 使用redis做全局锁,在
preHandle
方法中申请锁 - 在
afterCompletion
方法中释放锁
代码实现
首先给出redis申明锁/释放锁工具类:
/**
<pre>
redis分布式锁
https://wudashan.cn/2017/10/23/Redis-Distributed-Lock-Implement/
</pre>
* @author tanghc
*
*/
public class RedisTool {
private static final String LOCK_SUCCESS = "OK";
private static final String SET_IF_NOT_EXIST = "NX";
private static final String SET_WITH_EXPIRE_TIME = "PX";
private static final Long RELEASE_SUCCESS = 1L;
/**
* 尝试获取分布式锁
* @param jedis Redis客户端
* @param lockKey 锁
* @param requestId 请求标识
* @param expireTimeMilliseconds 超期时间,多少毫秒后这把锁自动释放
* @return 是否获取成功
*/
public static boolean tryGetDistributedLock(Jedis jedis, String lockKey, String requestId, int expireTimeMilliseconds ) {
String result = jedis.set(lockKey, requestId, SET_IF_NOT_EXIST, SET_WITH_EXPIRE_TIME, expireTimeMilliseconds);
if (LOCK_SUCCESS.equals(result)) {
return true;
}
return false;
}
/**
* 释放分布式锁
* @param jedis Redis客户端
* @param lockKey 锁
* @param requestId 请求标识
* @return 是否释放成功
*/
public static boolean releaseDistributedLock(Jedis jedis, String lockKey, String requestId) {
String script = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";
Object result = jedis.eval(script, Collections.singletonList(lockKey), Collections.singletonList(requestId));
if (RELEASE_SUCCESS.equals(result)) {
return true;
}
return false;
}
}
工具类的实现原理参考:Redis分布式锁的正确实现方式(Java版)
然后需要一个管理Jedis对象的工具类
@Component
public class JedisConfig {
@Value("${spring.redis.database}")
private String database;
@Value("${spring.redis.host}")
private String host;
@Value("${spring.redis.password}")
private String password;
@Value("${spring.redis.port}")
private String port;
@Value("${spring.redis.timeout}")
private String timeout;
@Value("${spring.redis.pool.max-idle}")
private String maxIdle;
@Value("${spring.redis.pool.min-idle}")
private String minIdle;
@Value("${spring.redis.pool.max-active}")
private String maxActive;
@Value("${spring.redis.pool.max-wait}")
private String maxWait;
@Bean
public JedisPool jedisPool() {
JedisPoolConfig jedisPoolConfig = new JedisPoolConfig();
jedisPoolConfig.setMaxIdle(NumberUtils.toInt(maxIdle,JedisPoolConfig.DEFAULT_MAX_IDLE));
jedisPoolConfig.setMinIdle(NumberUtils.toInt(minIdle,JedisPoolConfig.DEFAULT_MIN_IDLE));
jedisPoolConfig.setMaxTotal(NumberUtils.toInt(maxActive,JedisPoolConfig.DEFAULT_MAX_TOTAL));
jedisPoolConfig.setMaxWaitMillis(NumberUtils.toLong(maxWait, JedisPoolConfig.DEFAULT_MAX_WAIT_MILLIS));
jedisPoolConfig.setTestOnBorrow(true);
jedisPoolConfig.setTestOnReturn(true);
return new JedisPool(jedisPoolConfig,
host,
NumberUtils.toInt(port, 6379),
NumberUtils.toInt(timeout, 3000),
password,
NumberUtils.toInt(database, 0));
}
}
这里使用spring依赖注入一个JedisPool对象。
最后是编写拦截器,首先拦截器的伪代码如下:
/**
<pre>
业务处理锁(防暴击):
同一个人同一时间只能处理一个业务。
</pre>
* @author tanghc
*
*/
public class LockInterceptor extends ApiInterceptorAdapter {
private Logger logger = LoggerFactory.getLogger(getClass());
// 拦截接口名当中有这些关键字的
private static List<String> uriKeyList = Arrays.asList("order.cancel", "order.create");
private int lockExpireMilliseconds = 3000; // 锁过期时间,3秒
private JedisPool jedisPool;
public LockInterceptor() {
// 从spring容器中获取对象
jedisPool = SpringContextUtils.getBean(JedisPool.class);
}
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu)
throws Exception {
Jedis jedis = jedisPool.getResource();
try {
boolean hasLock = 申请redis锁
if(!hasLock) { // 如果没有得到锁,说明重复提交
response返回错误信息
return false;
}
}finally {
jedis.close(); // 最后别忘了关闭锁
}
return true;
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object serviceObj,
Object argu, Object result, Exception e) throws Exception {
// 释放锁
Jedis jedis = jedisPool.getResource();
try {
RedisTool.releaseDistributedLock(jedis, lockKey, requestId);
} finally {
jedis.close();
}
}
@Override
public boolean match(ApiMeta apiMeta) {
String name = apiMeta.getName();
return uriKeyList.contains(name); // 匹配接口,匹配到才执行该拦截器
}
}
完整代码
理解了伪代码逻辑后,再来看下完整代码
/**
<pre>
业务处理锁(防暴击):
同一个人同一时间只能处理一个业务。
</pre>
* @author tanghc
*
*/
public class LockInterceptor extends ApiInterceptorAdapter {
private Logger logger = LoggerFactory.getLogger(getClass());
// 拦截接口名当中有这些关键字的
private static List<String> uriKeyList = Arrays.asList("order.cancel", "order.create");
private int lockExpireMilliseconds = 3000; // 锁过期时间,3秒
private JedisPool jedisPool;
public LockInterceptor() {
jedisPool = SpringContextUtils.getBean(JedisPool.class);
}
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object serviceObj, Object argu)
throws Exception {
LoginUser loginUser = ApiUtil.getCurrentUser(); // 获取当前登录用户
String lockKey = this.getLockKey(loginUser);
String requestId = this.getRequestId(loginUser);
Jedis jedis = jedisPool.getResource();
try {
boolean hasLock = RedisTool.tryGetDistributedLock(jedis, lockKey, requestId , lockExpireMilliseconds);
// 如果没有取到锁,认为是暴击,直接返回
if(!hasLock) {
logger.warn("用户({},{})访问{}产生暴击!",loginUser.getId(),loginUser.getPhone(),ApiContext.getApiParam().fatchNameVersion());
ApiResult result = new ApiResult();
result.setCode(-102);
result.setMsg("您已提交,请耐心等待哦");
ResponseUtil.renderJson(response, JSON.toJSONString(result));
return false;
}
}finally {
jedis.close();
}
return true;
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object serviceObj,
Object argu, Object result, Exception e) throws Exception {
LoginUser loginUser = ApiUtil.getCurrentUser();
String lockKey = this.getLockKey(loginUser);
String requestId = this.getRequestId(loginUser);
Jedis jedis = jedisPool.getResource();
try {
RedisTool.releaseDistributedLock(jedis, lockKey, requestId);
} finally {
jedis.close();
}
}
private String getLockKey(LoginUser loginUser) {
return "api_lock_key:" + String.valueOf(loginUser.getId());
}
private String getRequestId(LoginUser loginUser) {
return "api_lock_request_id_" + loginUser.getId();
}
@Override
public boolean match(ApiMeta apiMeta) {
String name = apiMeta.getName();
return uriKeyList.contains(name);
}
}
单元测试
最后给出单元测试代码
public class ApiTest extends TestCase {
String url = "http://localhost:8080/api";
String appId = "test";
String secret = "123456";
String token = "0094FC708C34490F949A9FAB90453195";
/**
* 暴击测试,10条线程同时请求
* @throws InterruptedException
*/
@Test
public void testLock() throws InterruptedException {
int threadsCount = 10;
final CountDownLatch countDownLatch = new CountDownLatch(1);
final CountDownLatch count = new CountDownLatch(threadsCount);
AtomicInteger successCount = new AtomicInteger();
for (int i = 0; i < threadsCount; i++) {
new Thread(new Runnable() {
@Override
public void run() {
try {
countDownLatch.await(); // 等在这里,执行countDownLatch.countDown();集体触发
// 请求接口
Map<String, Object> busiParam = new HashMap<>();
String resp = doPost("order.create", busiParam);
if("0".equals(JSON.parseObject(resp).getString("code"))) {
successCount.incrementAndGet();
}
System.out.println(resp);
} catch (Exception e) {
e.printStackTrace();
}finally {
count.countDown();
}
}
}).start();
}
countDownLatch.countDown();
count.await();
System.out.println("成功条数:" + successCount.get());
}
private String doPost(String name, Map<String, Object> busiParam) throws IOException {
Map<String, String> param = new HashMap<String, String>();
String json = JSON.toJSONString(busiParam);
param.put(ParamNames.API_NAME, name);
param.put(ParamNames.APP_KEY_NAME, appId);
param.put(ParamNames.DATA_NAME, URLEncoder.encode(json, "UTF-8"));
param.put(ParamNames.TIMESTAMP_NAME, getTime());
param.put(ParamNames.VERSION_NAME, "");
param.put(ParamNames.FORMAT_NAME, "json");
param.put(ParamNames.ACCESS_TOKEN_NAME, token);
String sign = buildSign(param, secret);
param.put(ParamNames.SIGN_NAME, sign);
System.out.println("请求内容:" + JSON.toJSONString(param));
String resp = HttpUtil.post(url, param);
return resp;
}
/**
* 构建签名
*
* @param paramsMap
* 参数
* @param secret
* 密钥
* @return
* @throws IOException
*/
public static String buildSign(Map<String, ?> paramsMap, String secret) throws IOException {
Set<String> keySet = paramsMap.keySet();
List<String> paramNames = new ArrayList<String>(keySet);
Collections.sort(paramNames);
StringBuilder paramNameValue = new StringBuilder();
for (String paramName : paramNames) {
paramNameValue.append(paramName).append(paramsMap.get(paramName));
}
String source = secret + paramNameValue.toString() + secret;
return md5(source);
}
/**
* 生成md5,全部大写
*
* @param message
* @return
*/
public static String md5(String message) {
try {
// 1 创建一个提供信息摘要算法的对象,初始化为md5算法对象
MessageDigest md = MessageDigest.getInstance("MD5");
// 2 将消息变成byte数组
byte[] input = message.getBytes();
// 3 计算后获得字节数组,这就是那128位了
byte[] buff = md.digest(input);
// 4 把数组每一字节(一个字节占八位)换成16进制连成md5字符串
return byte2hex(buff);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* 二进制转十六进制字符串
*
* @param bytes
* @return
*/
private static String byte2hex(byte[] bytes) {
StringBuilder sign = new StringBuilder();
for (int i = 0; i < bytes.length; i++) {
String hex = Integer.toHexString(bytes[i] & 0xFF);
if (hex.length() == 1) {
sign.append("0");
}
sign.append(hex.toUpperCase());
}
return sign.toString();
}
public String getTime() {
return new SimpleDateFormat(ParamNames.TIMESTAMP_PATTERN).format(new Date());
}
}
如果您有其它好的方法和建议,欢迎在评论中讨论。