redisson多策略限流

目录

一、背景

二、实现代码

1. 先定义限流策略枚举类

2. 定义限流注解

3. 定义限流切面

4. 限流工具类


一、背景

日常开发中，限流场景非常多，限流策略也各不相同，例如按用户 ID、请求 IP 或请求 URI 限流。本文介绍如何基于 Redisson 的 RRateLimiter，以“注解 + 切面”的方式实现可按需切换的多策略限流，一次实现、多处复用。（注：下文工具类中尚未实现 IP 策略，使用前需自行补充。）

二、实现代码

1. 先定义限流策略枚举类

/**
 * Strategies for deriving the rate-limit key of a {@code @RedisLimit}-annotated method.
 *
 * <p>The declaration order below is intentionally unchanged so that {@link #ordinal()} values
 * stay stable.
 */
public enum LimitType {
    /** Key is taken verbatim from the annotation's {@code key()} attribute. */
    CUSTOM,
    /**
     * Requester IP.
     * NOTE(review): ProceedingJoinPointUtil registers no key strategy for this value.
     */
    IP,
    /** Method-level limit; key = class name + method name. */
    METHOD,
    /** Argument-level limit; key = class name + method name + arguments. */
    PARAMS,
    /** User-level limit; key = class name + method name + arguments + user id. */
    USER,
    /** Limit per request URI; key = request URI. */
    REQUEST_URI,
    /** Limit per request URI and user; key = request URI + user id. */
    REQUESTURI_USERID,
    /** Limit per user only; key = user id. */
    SINGLEUSER,
    /** Limit per method; key = class name + method name. */
    SINGLEMETHOD,
    /** Limit per request URI and arguments; key = request URI + arguments. */
    REQUEST_URI_PARAMS,
    /** Limit per request URI, arguments and user; key = request URI + arguments + user id. */
    REQUEST_URI_PARAMS_USERID
}

2. 定义限流注解

/**
 * Marks a method for Redisson-backed rate limiting.
 *
 * <p>NOTE(review): {@code ElementType.TYPE} is an allowed target, but the aspect's pointcut is
 * {@code @annotation(com.demo.aop.RedisLimit)}, which in Spring AOP matches method-level
 * annotations only; confirm whether class-level usage is actually enforced.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface RedisLimit {

    /** Prefix prepended to the generated key to form the limiter's Redis key. */
    String prefix() default "rateLimit:";

    /** Limiter key used when {@link #type()} is {@link LimitType#CUSTOM}; ignored for other types. */
    String key() default "";

    /** Length of the rate window, in seconds. */
    int time() default 1;

    /** Number of permits allowed per window; has no default and must be specified. */
    int count();

    /** Strategy used to derive the limiter key. */
    LimitType type() default LimitType.CUSTOM;
}

3. 定义限流切面（拦截带 @RedisLimit 注解的方法，并向 Redisson 限流器申请令牌）

@Aspect
@Component
@Slf4j
public class RedisLimitAspect {
    @Autowired
    private RedissonClient redissonClient;

    @Autowired
    private ProceedingJoinPointUtil proceedingJoinPointUtil;

    @Pointcut("@annotation(com.demo.aop.RedisLimit)")
    private void pointCut() {
    }

    @Around("pointCut() && @annotation(redisLimit)")
    private Object around(ProceedingJoinPoint joinPoint, RedisLimit redisLimit) throws Exception {
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();
        Class<?> objClass = method.getReturnType();
        Object object = objClass.newInstance();
        Object result = new RspInfoBO();
        try {
            Object generateKey = proceedingJoinPointUtil.getKey(joinPoint, redisLimit);
            //redis key
            String key = redisLimit.prefix() + generateKey.toString();
            //声明一个限流器
            RRateLimiter rateLimiter = redissonClient.getRateLimiter(key);
            //设置速率,time秒中产生count个令牌
            rateLimiter.trySetRate(RateType.OVERALL, redisLimit.count(), redisLimit.time(), RateIntervalUnit.SECONDS);
            // 试图获取一个令牌,获取到返回true
            boolean tryAcquire = rateLimiter.tryAcquire();
            if (!tryAcquire) {
                throw new Exception(400, "操作频繁,请稍后再试!");
            }
            result = joinPoint.proceed();
        } catch (Throwable throwable) {
            log.info("请求参数:{}", Arrays.toString(joinPoint.getArgs()));
            return object;
        }
        return result;
    }
}

4. 限流工具类（根据 LimitType 生成限流 key）

@Component
public class ProceedingJoinPointUtil {
    @Autowired
    private HttpServletRequest request;

    /** Key-generation strategy per LimitType; CUSTOM is resolved directly in {@link #getKey}. */
    private final Map<LimitType, Function<ProceedingJoinPoint, String>> functionMap =
            new java.util.EnumMap<>(LimitType.class);

    /**
     * Registers the key strategies.
     * NOTE(review): no strategy is registered for LimitType.IP, so using it fails in generateKey.
     */
    @PostConstruct
    void initMap() {
        functionMap.put(LimitType.METHOD, this::getMethodTypeKey);
        functionMap.put(LimitType.PARAMS, this::getParamsTypeKey);
        functionMap.put(LimitType.USER, this::getUserTypeKey);
        functionMap.put(LimitType.REQUEST_URI, joinPoint -> request.getRequestURI());
        functionMap.put(LimitType.REQUESTURI_USERID,
                joinPoint -> request.getRequestURI() + getToken(request));
        functionMap.put(LimitType.REQUEST_URI_PARAMS,
                joinPoint -> request.getRequestURI() + getParams(joinPoint));
        functionMap.put(LimitType.REQUEST_URI_PARAMS_USERID,
                joinPoint -> request.getRequestURI() + getParams(joinPoint) + getToken(request));
        functionMap.put(LimitType.SINGLEUSER, joinPoint -> String.valueOf(getToken(request)));
        // SINGLEMETHOD builds exactly the same key as METHOD.
        functionMap.put(LimitType.SINGLEMETHOD, this::getMethodTypeKey);
    }

    /**
     * Returns the caller's user id, used as the per-user part of a limiter key.
     *
     * <p>Reads the {@code user_id} header first and falls back to the {@code user_id} request
     * parameter when the header is blank. May return {@code null} or a blank string.
     * NOTE(review): the original checked the same header twice; it possibly meant a second,
     * different header (e.g. a token) — confirm.
     *
     * @param request the current HTTP request
     * @return the user id, or the raw parameter value (possibly {@code null})
     */
    private String getToken(HttpServletRequest request) {
        String userId = request.getHeader("user_id");
        if (StringUtils.isBlank(userId)) {
            userId = request.getParameter("user_id");
        }
        return userId;
    }

    /**
     * Builds the limiter key (without prefix) for the intercepted call.
     *
     * <p>For {@link LimitType#CUSTOM} the annotation's {@code key()} is used verbatim.
     * NOTE(review): its default is "", so all CUSTOM usages without an explicit key share one
     * limiter. Every other type is delegated to the registered strategy.
     *
     * @param joinPoint  the intercepted invocation
     * @param redisLimit the annotation on the intercepted method
     * @return the generated key
     */
    public Object getKey(ProceedingJoinPoint joinPoint, RedisLimit redisLimit) {
        LimitType type = redisLimit.type();
        if (type == LimitType.CUSTOM) {
            // Custom strategy: caller-supplied key.
            return redisLimit.key();
        }
        // Non-custom strategy: derive the key from method / request / user data.
        return generateKey(type, joinPoint);
    }

    /**
     * Generates a key using the strategy registered for {@code type}.
     *
     * @throws IllegalArgumentException if no strategy is registered for {@code type}
     */
    private Object generateKey(LimitType type, ProceedingJoinPoint joinPoint) {
        Function<ProceedingJoinPoint, String> strategy = functionMap.get(type);
        if (strategy == null) {
            throw new IllegalArgumentException("No key strategy registered for LimitType " + type);
        }
        return strategy.apply(joinPoint);
    }

    /** Method level: key = ClassName + MethodName. */
    private String getMethodTypeKey(ProceedingJoinPoint joinPoint) {
        StringBuilder sb = new StringBuilder();
        appendMethodName(joinPoint, sb);
        return sb.toString();
    }

    /** Argument level: key = ClassName + MethodName + Params. */
    private String getParamsTypeKey(ProceedingJoinPoint joinPoint) {
        StringBuilder sb = new StringBuilder();
        appendMethodName(joinPoint, sb);
        appendParams(joinPoint, sb);
        return sb.toString();
    }

    /** User level: key = ClassName + MethodName + Params + UserId. */
    private String getUserTypeKey(ProceedingJoinPoint joinPoint) {
        StringBuilder sb = new StringBuilder();
        appendMethodName(joinPoint, sb);
        appendParams(joinPoint, sb);
        appendUserId(sb);
        return sb.toString();
    }

    /** Appends the target's class name followed by the intercepted method's name. */
    private void appendMethodName(ProceedingJoinPoint joinPoint, StringBuilder sb) {
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();
        sb.append(joinPoint.getTarget().getClass().getName())
                .append(method.getName());
    }

    /**
     * Appends each argument's string form; {@code null} arguments are written as "null".
     *
     * <p>{@code MultipartFile} arguments are skipped. Their {@code toString()} is presumably not
     * overridden (identity-based), which would give every request a distinct key and defeat
     * the limit.
     */
    private void appendParams(ProceedingJoinPoint joinPoint, StringBuilder sb) {
        for (Object arg : joinPoint.getArgs()) {
            if (!(arg instanceof MultipartFile)) {
                sb.append(arg); // append(Object) writes "null" instead of throwing
            }
        }
    }

    /** Returns the concatenated argument string, built exactly as {@link #appendParams} does. */
    private String getParams(ProceedingJoinPoint joinPoint) {
        StringBuilder sb = new StringBuilder();
        appendParams(joinPoint, sb);
        return sb.toString();
    }

    /** Appends the caller's user id (see {@link #getToken}). */
    private void appendUserId(StringBuilder sb) {
        sb.append(getToken(request));
    }
}

如有帮助,请多多点赞关注,如有疑问,请留言私信,我会及时回复。

猜你喜欢

转载自blog.csdn.net/xrq1995/article/details/127621764