This article demonstrates how a Spring Boot project can combine AOP with Redis and a Lua script to implement distributed rate limiting, which protects the API from frequent malicious access.
1. Redis configuration class
/**
 * Spring configuration that exposes the rate-limiting Lua script as a bean.
 */
@Configuration
public class RedisConfig {

    /**
     * Builds the {@link RedisScript} that wraps {@code scripts/redis/limit.lua}.
     *
     * @return a script whose result is converted to {@link Long}
     */
    @Bean
    public RedisScript<Long> limitRedisScript() {
        // Parameterized (not raw) so the result type is checked at compile time.
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/redis/limit.lua")));
        redisScript.setResultType(Long.class);
        return redisScript;
    }
}
2. Lua script
-- Sliding-window rate limiter backed by a sorted set.
-- KEYS[1] = limiter key
-- ARGV[1] = current time (ms), ARGV[2] = window length / key TTL (ms),
-- ARGV[3] = newest timestamp that is already outside the window,
-- ARGV[4] = maximum number of visits allowed inside the window.
-- Returns 0 when the limit is reached, otherwise the visit count including this call.

-- Lua table indices start at 1
local key = KEYS[1]
local now = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local expired = tonumber(ARGV[3])
-- Maximum number of visits
local max = tonumber(ARGV[4])

-- Clear outdated data: remove every element whose score lies in [0, expired].
-- The caller computes expired as (current time in ms - window length in ms).
redis.call('zremrangebyscore', key, 0, expired)

-- Number of visits still inside the window
local current = tonumber(redis.call('zcard', key))
-- Named "visits" rather than "next" so the builtin next() is not shadowed
local visits = current + 1

if visits > max then
    -- Limit reached
    return 0
end

-- The score is the timestamp. The member is the timestamp plus the visit
-- number: if the member were the bare timestamp, two requests arriving in the
-- same millisecond would overwrite each other and be counted only once.
redis.call('zadd', key, now, ARGV[1] .. ':' .. visits)
-- Reset the expiration time of the zset on every accepted visit (milliseconds)
redis.call('pexpire', key, ttl)
return visits
3. Annotation
/**
 * Method-level annotation that declares a request quota for the annotated method.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /** Quota used when neither {@link #value()} nor {@link #max()} is given. */
    long DEFAULT_REQUEST = 10;

    /**
     * Maximum number of requests; alias of {@link #max()}.
     */
    @AliasFor("max")
    long value() default DEFAULT_REQUEST;

    /**
     * Maximum number of requests; alias of {@link #value()}.
     */
    @AliasFor("value")
    long max() default DEFAULT_REQUEST;

    /**
     * Rate-limit key; empty by default.
     */
    String key() default "";

    /**
     * Length of the timeout period, expressed in {@link #timeUnit()}; defaults to 1.
     */
    long timeout() default 1;

    /**
     * Unit of {@link #timeout()}; defaults to minutes.
     */
    TimeUnit timeUnit() default TimeUnit.MINUTES;
}
4. Aspect
@Aspect @Component @RequiredArgsConstructor(onConstructor_ = @Autowired) @ Slf4j public class LimiteAspect { private final static String SEPARATOR = ":"; private final static String REDIS_LIMIT_KEY_PREFIX = "limit:"; private final StringRedisTemplate stringRedisTemplate; private final RedisScript<Long> limitRedisScript; @Pointcut("@annotation(com.juejueguai.springbootdemoratelimitredis.annotation.RateLimiter)") public void rateLimit () { } @Around("rateLimit()") public Object pointcut(ProceedingJoinPoint point) throws Throwable { MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); // Get the RateLimiter annotation through AnnotationUtils.findAnnotation RateLimiter rateLimiter = AnnotationUtils.findAnnotation (method, RateLimiter.class); if (rateLimiter != null) { String key = rateLimiter.key (); // By default, use class name + method name as the key prefix for current limiting if (StrUtil.isBlank(key)) { key = method.getDeclaringClass().getName()+StrUtil.DOT+method.getName(); } // The key of the final current limit is the prefix + IP address // TODO: At this time, it is necessary to consider the situation of multi-user access in the LAN, so it is more reasonable to add method parameters to the key later key = key + SEPARATOR + IpUtil.getIpAddr(); long max = rateLimiter.max (); long timeout = rateLimiter.timeout(); TimeUnit timeUnit = rateLimiter.timeUnit (); boolean limited = shouldLimited(key, max, timeout, timeUnit); if (limited) { throw new RuntimeException("Hand speed is too fast, go slower~"); } } return point.proceed(); } private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) { // The final key format is: // limit: custom key: IP // limit: class name. 
method name: IP key = REDIS_LIMIT_KEY_PREFIX + key; // Use the unit milliseconds uniformly long ttl = timeUnit.toMillis(timeout); // The current time in milliseconds long now = Instant.now().toEpochMilli(); long expired = now - ttl; // Note that this must be converted to String, otherwise it will report an error java.lang.Long cannot be cast to java.lang.String Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + ""); if (executeTimes != null) { if (executeTimes == 0) { log.error("[{}] The access limit has been reached within {} milliseconds per unit time, the current interface limit is {}", key, ttl, max); return true; } else { log.info("[{}] visit {} times in unit time {} milliseconds", key, ttl, executeTimes); return false; } } return false; } }
5. Global exception handler
/**
 * Controller advice that turns any {@link RuntimeException} into a response body
 * carrying the exception message under the {@code msg} key.
 */
@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {

    /**
     * @param ex the exception raised while handling the request
     * @return a {@link Dict} whose {@code msg} entry is {@code ex.getMessage()}
     */
    @ExceptionHandler(RuntimeException.class)
    public Dict handler(RuntimeException ex) {
        Dict body = Dict.create();
        body.set("msg", ex.getMessage());
        return body;
    }
}
6. IP utility class
@ Slf4j public class IpUtil { private final static String UNKNOWN = "unknown"; private final static int MAX_LENGTH = 15; /** * Get IP address * Using reverse proxy software such as Nginx, you cannot obtain the IP address through request.getRemoteAddr() * If a multi-level reverse proxy is used, the value of X-Forwarded-For is not only one, but a string of IP addresses. The first non-unknown valid IP string in X-Forwarded-For is the real IP address */ public static String getIpAddr() { HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest(); String ip = null; try { ip = request.getHeader("x-forwarded-for"); if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { ip = request.getHeader("Proxy-Client-IP"); } if (StrUtil.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) { ip = request.getHeader("WL-Proxy-Client-IP"); } if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { ip = request.getHeader("HTTP_CLIENT_IP"); } if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { ip = request.getHeader("HTTP_X_FORWARDED_FOR"); } if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { ip = request.getRemoteAddr(); } } catch (Exception e) { log.error("IPUtils ERROR ", e); } // Use a proxy, get the first IP address if (!StrUtil.isEmpty(ip) && ip.length() > MAX_LENGTH) { if (ip.indexOf(StrUtil.COMMA) > 0) { ip = ip.substring(0, ip.indexOf(StrUtil.COMMA)); } } return ip; } }
7. Interface test
/**
 * Demo endpoints: /test1 and /test3 are rate limited, /test2 is not.
 */
@Slf4j
@RestController
@RequestMapping
public class TestController {

    /** Description shared by the two rate-limited endpoints. */
    private static final String LIMITED_DESCRIPTION =
            "Don't want to see me all the time, if you don't believe me, refresh it quickly~";

    /** Limited to 5 requests, using the annotation's default key and window. */
    @RateLimiter(value = 5)
    @GetMapping("/test1")
    public Dict test1() {
        log.info("[test1] was executed...");
        return helloWorld(LIMITED_DESCRIPTION);
    }

    /** Not rate limited. */
    @GetMapping("/test2")
    public Dict test2() {
        log.info("[test2] was executed...");
        return helloWorld("I've been there all the time, I'm here to stay away");
    }

    /** Limited to 2 requests under a custom key. */
    @RateLimiter(value = 2, key = "Test custom key")
    @GetMapping("/test3")
    public Dict test3() {
        log.info("[test3] was executed...");
        return helloWorld(LIMITED_DESCRIPTION);
    }

    /** Builds the common response body with the given description. */
    private static Dict helloWorld(String description) {
        Dict body = Dict.create();
        body.set("msg", "hello,world!");
        body.set("description", description);
        return body;
    }
}
8. Stress testing with JMeter, a stress testing tool