Spring Boot: using Redis + Lua to implement rate limiting (or to prevent repeated submissions)

This article demonstrates how a Spring Boot project can combine AOP with Redis + Lua scripts to implement distributed rate limiting, protecting the API from frequent or malicious access.

1. Redis configuration

@Configuration
public class RedisConfig {

    /**
     * Registers the rate-limiting Lua script as a {@link RedisScript} bean.
     *
     * <p>The script is loaded from {@code scripts/redis/limit.lua} on the classpath. Its result
     * is read as a {@link Long}: {@code 0} when the limit has been reached, otherwise the number
     * of requests recorded in the current window (including the current one).
     *
     * @return the typed script, injected into the rate-limiting aspect
     */
    @Bean
    public RedisScript<Long> limitRedisScript() {
        // Use the parameterized type rather than the raw DefaultRedisScript, so the compiler
        // checks that the bean really is a RedisScript<Long> (no unchecked conversion).
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/redis/limit.lua")));
        redisScript.setResultType(Long.class);
        return redisScript;
    }
}

2. Lua script (`scripts/redis/limit.lua`)

-- Sliding-window rate limiter backed by a sorted set (ZSET).
-- (Lua comments start with "--"; a single "-" is a syntax error.)
--
-- KEYS[1]  rate-limit key
-- ARGV[1]  now     : current time in milliseconds, used as the entry score
-- ARGV[2]  ttl     : window length in milliseconds, also used as the key expiry
-- ARGV[3]  expired : now - ttl; entries scored at or below this are outside the window
-- ARGV[4]  max     : maximum number of requests allowed within the window
--
-- Returns 0 when the limit has been reached (the request is NOT recorded);
-- otherwise the number of requests in the window, including this one.
-- Note: Lua tables (KEYS / ARGV) are 1-indexed.
local key = KEYS[1]
local now = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local expired = tonumber(ARGV[3])
local max = tonumber(ARGV[4])

-- Drop every entry whose score (timestamp) has fallen out of the window.
redis.call('zremrangebyscore', key, 0, expired)

-- Requests still inside the window, plus the current one.
local current = tonumber(redis.call('zcard', key))
local count = current + 1

if count > max then
  -- Limit reached: reject without recording this request.
  return 0
end

-- Record this request with the timestamp as its score. ZSET members must be
-- unique: using the bare timestamp made two requests in the same millisecond
-- collapse into one entry, so the second was never counted. Suffixing the
-- request's position in the window keeps members distinct.
redis.call('zadd', key, now, ARGV[1] .. '-' .. count)
-- Refresh the expiry on every accepted request so idle keys are evicted.
redis.call('pexpire', key, ttl)
return count

 

3. Annotation

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {

    /** Default number of requests allowed per time window. */
    long DEFAULT_REQUEST = 10L;

    /**
     * Maximum number of requests allowed per window; alias of {@link #max()}.
     *
     * <p>NOTE(review): {@code @AliasFor} is only honored when the annotation is read through
     * Spring's annotation utilities (e.g. {@code AnnotationUtils.findAnnotation}), not via plain
     * {@code Method.getAnnotation}.
     */
    @AliasFor("max")
    long value() default DEFAULT_REQUEST;

    /** Maximum number of requests allowed per window; alias of {@link #value()}. */
    @AliasFor("value")
    long max() default DEFAULT_REQUEST;

    /**
     * Rate-limit key. When left blank, the declaring class name and method name are used
     * instead.
     */
    String key() default "";

    /** Length of the time window, expressed in {@link #timeUnit()} units (default {@code 1}). */
    long timeout() default 1;

    /** Unit of {@link #timeout()} (default: minutes). */
    TimeUnit timeUnit() default TimeUnit.MINUTES;
}

4. Aspect

@Aspect
@Component
@RequiredArgsConstructor(onConstructor_ = @Autowired)
@Slf4j
public class LimiteAspect {
    private final static String SEPARATOR = ":";
    private final static String REDIS_LIMIT_KEY_PREFIX = "limit:";
    private final StringRedisTemplate stringRedisTemplate;
    private final RedisScript<Long> limitRedisScript;

    /** Matches every method annotated with {@code @RateLimiter}. */
    @Pointcut("@annotation(com.juejueguai.springbootdemoratelimitredis.annotation.RateLimiter)")
    public void rateLimit() {
    }

    /**
     * Checks the caller's rate limit before invoking the annotated method.
     *
     * @param point the intercepted invocation
     * @return whatever the target method returns
     * @throws RuntimeException when the caller has exceeded the configured limit
     * @throws Throwable anything thrown by the target method itself
     */
    @Around("rateLimit()")
    public Object pointcut(ProceedingJoinPoint point) throws Throwable {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        // AnnotationUtils.findAnnotation returns a synthesized annotation, which is what makes
        // the @AliasFor pairing of value()/max() work.
        RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
        if (rateLimiter != null) {
            String key = buildKey(rateLimiter, method);
            boolean limited = shouldLimited(key, rateLimiter.max(), rateLimiter.timeout(), rateLimiter.timeUnit());
            if (limited) {
                throw new RuntimeException("Hand speed is too fast, go slower~");
            }
        }
        return point.proceed();
    }

    /**
     * Builds the per-client key (without the Redis prefix): the custom key, or
     * {@code ClassName.methodName} when none is set, followed by the client IP.
     */
    private String buildKey(RateLimiter rateLimiter, Method method) {
        String key = rateLimiter.key();
        if (StrUtil.isBlank(key)) {
            key = method.getDeclaringClass().getName() + StrUtil.DOT + method.getName();
        }
        // TODO: clients behind the same NAT / LAN share one IP and therefore one bucket;
        // consider adding a user id or method arguments to the key.
        return key + SEPARATOR + IpUtil.getIpAddr();
    }

    /**
     * Records one request for {@code key} in Redis and reports whether the limit was exceeded.
     *
     * <p>Final key format: {@code limit:<custom key>:<IP>} or {@code limit:<Class.method>:<IP>}.
     *
     * @return {@code true} if the request must be rejected; {@code false} otherwise, including
     *         when Redis returns no result (fail-open)
     */
    private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) {
        key = REDIS_LIMIT_KEY_PREFIX + key;
        // The script works in milliseconds throughout.
        long ttl = timeUnit.toMillis(timeout);
        long now = Instant.now().toEpochMilli();
        long expired = now - ttl;
        // StringRedisTemplate serializes every argument as a String, so they must be passed as
        // Strings; a Long fails with "java.lang.Long cannot be cast to java.lang.String".
        Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key),
                String.valueOf(now), String.valueOf(ttl), String.valueOf(expired), String.valueOf(max));
        if (executeTimes == null) {
            // No result from Redis: let the request through rather than blocking all traffic.
            return false;
        }
        if (executeTimes == 0L) {
            log.error("[{}] The access limit has been reached within {} milliseconds per unit time, the current interface limit is {}", key, ttl, max);
            return true;
        }
        // Arguments follow the placeholder order: count first, then window length.
        log.info("[{}] visit {} times in unit time {} milliseconds", key, executeTimes, ttl);
        return false;
    }
}

5. Global exception handler

@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {

    /**
     * Translates any {@link RuntimeException} escaping a controller into a body of the form
     * {@code {"msg": <exception message>}}. This includes the rate limiter's rejection.
     *
     * <p>NOTE(review): there is no {@code @ResponseStatus}, so the response presumably uses
     * HTTP 200. Consider 429 for rate-limit rejections.
     *
     * @param ex the exception thrown by the handler method
     * @return a map holding the exception message under {@code "msg"}
     */
    @ExceptionHandler(RuntimeException.class)
    public Dict handler(RuntimeException ex) {
        // Log with the stack trace before translating. Otherwise every RuntimeException,
        // including genuine bugs such as NPEs, vanishes and only its message reaches the client.
        log.warn("Request failed: {}", ex.getMessage(), ex);
        return Dict.create().set("msg", ex.getMessage());
    }
}

6. IP utility

@Slf4j
public class IpUtil {
    private final static String UNKNOWN = "unknown";

    /**
     * Headers that reverse proxies commonly use to carry the original client IP. They are
     * checked in this order before falling back to {@code request.getRemoteAddr()}.
     */
    private static final String[] PROXY_HEADERS = {
            "x-forwarded-for", "Proxy-Client-IP", "WL-Proxy-Client-IP", "HTTP_CLIENT_IP", "HTTP_X_FORWARDED_FOR"
    };

    /**
     * Gets the client IP address of the current HTTP request.
     *
     * <p>Behind a reverse proxy such as Nginx, {@code request.getRemoteAddr()} returns the
     * proxy's address, so the proxy headers are consulted first. With several proxies,
     * X-Forwarded-For holds a comma-separated list; the first entry that is neither blank nor
     * "unknown" is returned.
     *
     * <p>SECURITY NOTE: these headers are supplied by the client unless a trusted proxy
     * overwrites them. A client can therefore spoof them, for example to rotate IPs and evade
     * per-IP rate limits.
     *
     * @return the client IP, or {@code null} when no servlet request is bound to the current
     *         thread or when reading the request fails
     */
    public static String getIpAddr() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            // Previously this was an NPE / ClassCastException outside a servlet request thread.
            log.warn("IPUtils: no servlet request bound to the current thread");
            return null;
        }
        HttpServletRequest request = ((ServletRequestAttributes) attributes).getRequest();
        String ip = null;
        try {
            for (String header : PROXY_HEADERS) {
                ip = request.getHeader(header);
                if (!isUnknown(ip)) {
                    break;
                }
            }
            if (isUnknown(ip)) {
                ip = request.getRemoteAddr();
            }
        } catch (Exception e) {
            log.error("IPUtils ERROR ", e);
        }
        return firstValidIp(ip);
    }

    /** True when the value carries no usable address (null, empty, or "unknown"). */
    private static boolean isUnknown(String ip) {
        return StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip);
    }

    /**
     * Picks the first usable entry of a possibly comma-separated IP list.
     *
     * <p>This always splits on commas. The old {@code length > 15} guard left short lists such
     * as {@code "1.1.1.1,2.2.2.2"} (exactly 15 chars) unsplit. It also skips "unknown" entries.
     * If no entry is usable, the input is returned unchanged.
     */
    private static String firstValidIp(String ip) {
        if (StrUtil.isEmpty(ip)) {
            return ip;
        }
        for (String candidate : ip.split(StrUtil.COMMA)) {
            String trimmed = candidate.trim();
            if (!isUnknown(trimmed)) {
                return trimmed;
            }
        }
        return ip;
    }
}

 

7. Test controller
@RestController
@RequestMapping
@Slf4j
public class TestController {

    /**
     * Rate-limited endpoint: at most 5 calls per window (1 minute by default) per client IP.
     * The key is derived from the class and method name.
     */
    @RateLimiter(value = 5)
    @GetMapping("/test1")
    public Dict test1() {
        log.info("[test1] was executed...");
        return reply("Don't want to see me all the time, if you don't believe me, refresh it quickly~");
    }

    /** Unrestricted endpoint, used for comparison. */
    @GetMapping("/test2")
    public Dict test2() {
        log.info("[test2] was executed...");
        return reply("I've been there all the time, I'm here to stay away");
    }

    /** Rate-limited endpoint using an explicit key: at most 2 calls per window per client IP. */
    @RateLimiter(value = 2, key = "Test custom key")
    @GetMapping("/test3")
    public Dict test3() {
        log.info("[test3] was executed...");
        return reply("Don't want to see me all the time, if you don't believe me, refresh it quickly~");
    }

    /** Builds the common response body: a fixed greeting followed by the endpoint's description. */
    private static Dict reply(String description) {
        Dict body = Dict.create();
        body.set("msg", "hello,world!");
        body.set("description", description);
        return body;
    }
}

8. Load testing with JMeter

Guess you like

Source: blog.csdn.net/A___B___C/article/details/108275800