springboot中使用lua脚本+aop作限流访问案例代码


1.限流注解

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Limit {
    
    

    // 资源名称,用于描述接口功能
    String name() default "";

    // 资源 key
    String key() default "";

    // key 前缀
    String prefix() default "";

    // 时间,单位秒
    int period();

    // 限制访问次数
    int count();

    // 限制类型
    LimitType limitType() default LimitType.CUSTOMER;

}

2.redis配置

spring:
  redis:
    #数据库索引
    database: 0
    host: ....
    port: 6379
    password:
    jedis:
      pool:
        max-active: 8 # 连接池最大连接数(使用负值表示没有限制)
        max-wait: -1ms # 连接池最大阻塞等待时间(使用负值表示没有限制)
        max-idle: 8 # 连接池中的最大空闲连接
        min-idle: 0 # 连接池中的最小空闲连接
    #连接超时时间
    timeout: 5000
@Slf4j
@Configuration
@EnableCaching
@ConditionalOnClass(RedisOperations.class)
//Spring工程中引用了redis相关的包 才会构建这个bean
@EnableConfigurationProperties(RedisProperties.class)
public class RedisConfig extends CachingConfigurerSupport {
    
    

    /**
     *  设置 redis 数据默认过期时间,默认2小时
     *  设置@cacheable 序列化方式
     */
    @Bean
    public RedisCacheConfiguration redisCacheConfiguration() {
    
    
        FastJsonRedisSerializer<Object> fastJsonRedisSerializer = new FastJsonRedisSerializer<>(Object.class);
        RedisCacheConfiguration configuration = RedisCacheConfiguration.defaultCacheConfig();
        configuration = configuration.serializeValuesWith(RedisSerializationContext.SerializationPair.fromSerializer(fastJsonRedisSerializer)).entryTtl(Duration.ofHours(2));
        return configuration;
    }

    @SuppressWarnings("all")
    @Bean(name = "redisTemplate")
    @ConditionalOnMissingBean(name = "redisTemplate")
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
    
    
        RedisTemplate<Object, Object> template = new RedisTemplate<>();
        //序列化
        FastJsonRedisSerializer<Object> fastJsonRedisSerializer = new FastJsonRedisSerializer<>(Object.class);
        // value值的序列化采用fastJsonRedisSerializer
        template.setValueSerializer(fastJsonRedisSerializer);
        template.setHashValueSerializer(fastJsonRedisSerializer);
        // 全局开启AutoType,这里方便开发,使用全局的方式
        ParserConfig.getGlobalInstance().setAutoTypeSupport(true);
        // 建议使用这种方式,小范围指定白名单
        // ParserConfig.getGlobalInstance().addAccept("me.zhengjie.domain");
        // key的序列化采用StringRedisSerializer
        template.setKeySerializer(new StringRedisSerializer());
        template.setHashKeySerializer(new StringRedisSerializer());
        template.setConnectionFactory(redisConnectionFactory);
        return template;
    }

    /**
     * 自定义缓存key生成策略,默认将使用该策略
     */
    @Bean
    @Override
    public KeyGenerator keyGenerator() {
    
    
        return (target, method, params) -> {
    
    
            Map<String, Object> container = new HashMap<>(3);
            Class<?> targetClassClass = target.getClass();
            // 类地址
            container.put("class", targetClassClass.toGenericString());
            // 方法名称
            container.put("methodName", method.getName());
            // 包名称
            container.put("package", targetClassClass.getPackage());
            // 参数列表
            for (int i = 0; i < params.length; i++) {
    
    
                container.put(String.valueOf(i), params[i]);
            }
            // 转为JSON字符串
            String jsonString = JSON.toJSONString(container);
            // 做SHA256 Hash计算,得到一个SHA256摘要作为Key
            return DigestUtils.sha256Hex(jsonString);
        };
    }

    @Bean
    @Override
    public CacheErrorHandler errorHandler() {
    
    
        // 异常处理,当Redis发生异常时,打印日志,但是程序正常走
        log.info("初始化 -> [{}]", "Redis CacheErrorHandler");
        return new CacheErrorHandler() {
    
    
            @Override
            public void handleCacheGetError(RuntimeException e, Cache cache, Object key) {
    
    
                log.error("Redis occur handleCacheGetError:key -> [{}]", key, e);
            }

            @Override
            public void handleCachePutError(RuntimeException e, Cache cache, Object key, Object value) {
    
    
                log.error("Redis occur handleCachePutError:key -> [{}];value -> [{}]", key, value, e);
            }

            @Override
            public void handleCacheEvictError(RuntimeException e, Cache cache, Object key) {
    
    
                log.error("Redis occur handleCacheEvictError:key -> [{}]", key, e);
            }

            @Override
            public void handleCacheClearError(RuntimeException e, Cache cache) {
    
    
                log.error("Redis occur handleCacheClearError:", e);
            }
        };
    }

}

/**
 * Value 序列化
 *
 * @author /
 * @param <T>
 */
class FastJsonRedisSerializer<T> implements RedisSerializer<T> {
    
    

    private final Class<T> clazz;

    FastJsonRedisSerializer(Class<T> clazz) {
    
    
        super();
        this.clazz = clazz;
    }

    @Override
    public byte[] serialize(T t) {
    
    
        if (t == null) {
    
    
            return new byte[0];
        }
        return JSON.toJSONString(t, SerializerFeature.WriteClassName).getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public T deserialize(byte[] bytes) {
    
    
        if (bytes == null || bytes.length <= 0) {
    
    
            return null;
        }
        String str = new String(bytes, StandardCharsets.UTF_8);
        return JSON.parseObject(str, clazz);
    }

}

/**
 * 重写序列化器
 *
 * @author /
 */
class StringRedisSerializer implements RedisSerializer<Object> {
    
    

    private final Charset charset;

    StringRedisSerializer() {
    
    
        this(StandardCharsets.UTF_8);
    }

    private StringRedisSerializer(Charset charset) {
    
    
        Assert.notNull(charset, "Charset must not be null!");
        this.charset = charset;
    }

    @Override
    public String deserialize(byte[] bytes) {
    
    
        return (bytes == null ? null : new String(bytes, charset));
    }

    @Override
    public byte[] serialize(Object object) {
    
    
        String string = JSON.toJSONString(object);
        if (StringUtils.isBlank(string)) {
    
    
            return null;
        }
        string = string.replace("\"", "");
        return string.getBytes(charset);
    }
}

3.aop配置

@Aspect
@Component
public class LimitAspect {
    
    

    private final RedisTemplate<Object, Object> redisTemplate;
    private static final Logger logger = LoggerFactory.getLogger(LimitAspect.class);

    public LimitAspect(RedisTemplate<Object, Object> redisTemplate) {
    
    
        this.redisTemplate = redisTemplate;
    }

    @Pointcut("@annotation(co.yixiang.annotation.Limit)")
    public void pointcut() {
    
    
    }

    @Around("pointcut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
    
    
        HttpServletRequest request = RequestHolder.getHttpServletRequest();
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method signatureMethod = signature.getMethod();
        Limit limit = signatureMethod.getAnnotation(Limit.class);
        LimitType limitType = limit.limitType();
        String key = limit.key();
        if (StringUtils.isEmpty(key)) {
    
    
            if (limitType == LimitType.IP) {
    
    
                key = StringUtils.getIp(request);
            } else {
    
    
                key = signatureMethod.getName();
            }
        }
//ImmutableList是一个不可变、线程安全的列表集合,它只会获取传入对象的一个副本,而不会影响到原来的变量或者对象
        // //获取一个有两个元素的不可变集合对象
        //        ImmutableList<String> list3 = ImmutableList .<String>of("12","23");
        //  // limit.prefix():key prefix
        ImmutableList<Object> keys = ImmutableList.of(StringUtils.join(limit.prefix(), "_", key, "_", request.getRequestURI().replaceAll("/", "_")));

        String luaScript = buildLuaScript();
        RedisScript<Number> redisScript = new DefaultRedisScript<>(luaScript, Number.class);
        Number count = redisTemplate.execute(redisScript, keys, limit.count(), limit.period());
        if (null != count && count.intValue() <= limit.count()) {
    
    
            logger.info("第{}次访问key为 {},描述为 [{}] 的接口", count, keys, limit.name());
            //count, keys, limit.name()为参数
            return joinPoint.proceed();
        } else {
    
    
            throw new BadRequestException("访问次数受限制");
        }
    }

    /**
     * 限流脚本
     */
    private String buildLuaScript() {
    
    
        return "local c" +
                "\nc = redis.call('get',KEYS[1])" +
                "\nif c and tonumber(c) > tonumber(ARGV[1]) then" +
                "\nreturn c;" +
                "\nend" +
                "\nc = redis.call('incr',KEYS[1])" +
                "\nif tonumber(c) == 1 then" +
                "\nredis.call('expire',KEYS[1],ARGV[2])" +
                "\nend" +
                "\nreturn c;";
    }
}

4.controller层测试

/**
 * @author /
 * 接口限流测试类
 */
@RestController
@RequestMapping("/api/limit")
@Api(tags = "系统:限流测试管理")
public class LimitController {
    
    

    private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger();

    /**
     * 测试限流注解,下面配置说明该接口 60秒内最多只能访问 10次,保存到redis的键名为 limit_test,
     */
    @GetMapping
    @AnonymousAccess
    @ApiOperation("测试")
    @Limit(key = "test", period = 60, count = 10, name = "testLimit", prefix = "limit")
    public int testLimit() {
    
    
        return ATOMIC_INTEGER.incrementAndGet();
    }
}

拓展:Atomic类的学习

JUC包提供了一系列的原子性操作类,这些类都是使用非阻塞算法CAS实现的,相比使用锁实现原子性操作这在性能上有很大提高。

JUC并发包中包含有AtomicInteger、AtomicLong和AtomicBoolean等原子性操作类,它们的原理类似。AtomicLong是原子性递增或者递减类,其内部使用Unsafe来实现,我们看下面的代码。

public class AtomicLong extends Number implements java.io.Serializable {
    
    
    private static final long serialVersionUID = 1927816293512124184L;

   //判断jvm是否支持long类型无锁CAS
    static final boolean VM_SUPPORTS_LONG_CAS = VMSupportsCS8();

    /**
     * Returns whether underlying JVM supports lockless CompareAndSet
     * for longs. Called only once and cached in VM_SUPPORTS_LONG_CAS.
     */
    private static native boolean VMSupportsCS8();

    private static final Unsafe U = Unsafe.getUnsafe();
    private static final long VALUE
        = U.objectFieldOffset(AtomicLong.class, "value");

    private volatile long value;//实际变量值
    //value被声明为volatile的,这是为了在多线程下保证内存可见性,value是具体存放计数的变量。

    ....
    }

注意:private static final Unsafe U = Unsafe.getUnsafe();为何能通过Unsafe.getUnsafe()方法获取到Unsafe类的实例?其实这是因为AtomicLong类也是在rt.jar包下面的,AtomicLong类就是通过BootStarp类加载器进行加载的。
jdk8中的getAndAddLong方法:

public final getAndAddLong(Object paramObject,long paramLong1,long paramlong2)
{
    
    
long l;
do{
    
    
l = getLongvolatile(paramObject,paramLong);
}while(!compareAndSwapLong(paramObject,paramLong1,1,1 + paramLong2));
return l;
}

可以看到,JDK 7的AtomicLong中的循环逻辑已经被JDK 8中的原子操作类UNsafe内置了,之所以内置应该是考虑到这个函数在其他地方也会用到,而内置可以提高复用性。
下面通过一个多线程使用AtomicLong统计0的个数的例子来加深对AtomicLong的理解。

/**
统计0的个数
*/
public class Atomic
{
    
    
//(10)创建Long型原子计数器
private static AtomicLong atomicLong=new AtomicLong();
//(11)创建数据源
private static Integer[]arrayone=new Integer[]{
    
    01230560560};
private static Integer[]arrayTwo=new Integer[]{
    
    101230560560};
public static void main(String[]args)throws InterruptedException
{
    
    
//(12)线程one统计数组array0ne中0的个数
Thread threadOne=new Threadnew Runnable(){
    
    
@Override
public void run(){
    
    
int size=arrayone.length;
for(inti=0;i<size;++i){
    
    
if(arrayone[i].intValue()==0{
    
    
atomicLong.incrementAndGet();
}
}}
});
//(13)线程two统计数组arrayTwo中0的个数
Thread threadTwo=new Threadnew Runnable(){
    
    
@override
public void run(){
    
    
int size=arrayTwo.length;
for(inti=0;i<size;++i){
    
    
if(arrayTwo[i].intValue()=0{
    
    
atomicLong.incrementAndGet();
}
}
}
});
//(14)启动子线程
threadone.start();
threadTwo.start();
//(15)等待线程执行完毕
threadone.join();
threadTwo.join();
System.out.println("count0:"+atomicLong.get());
}

输出count0:7
如上代码中的两个线程各自统计自己所持数据中0的个数,每当找到一个0就会调用AtomicLong的原子性递增方法。 在没有原子类的情况下,实现计数器需要使用一定的同步措施,比如使用synchronized关键字等,但是这些都是阻塞算法,对性能有一定损耗,而这些原子操作类都使用CAS非阻塞算法,性能更好。但是在高并发情况下AtomicLong还会存在性能问题。实际上JDK 8提供了一个在高并发下性能更好的LongAdder类

lua脚本学习

关于lua脚本参考了《redis入门指南》
使用脚本的好处如下:

1.减少网络开销:本来5次网络请求的操作,可以用一个请求完成,原先5次请求的逻辑放在redis服务器上完成。使用脚本,减少了网络往返时延。
2.原子操作:Redis会将整个脚本作为一个整体执行,中间不会被其他命令插入。
3.复用:客户端发送的脚本会永久存储在Redis中,意味着其他客户端可以复用这一脚本而不需要使用代码完成同样的逻辑。

示例:
写一个.lua脚本

local times = redis.call('incr',KEYS[1])

if times == 1 then
    redis.call('expire',KEYS[1], ARGV[1])
end

if times > tonumber(ARGV[2]) then
    return 0
end
return 1

在redis客户端机器上,如何测试这个脚本呢?如下:

redis-cli --eval ratelimiting.lua rate.limitingl:127.0.0.1 , 10 3
–eval参数是告诉redis-cli读取并运行后面的Lua脚本,ratelimiting.lua是脚本的位置,后面跟着是传给Lua脚本的参数。其中",“前的rate.limiting:127.0.0.1是要操作的键,可以再脚本中用KEYS[1]获取,”,“后面的10和3是参数,在脚本中能够使用ARGV[1]和ARGV[2]获得。注:”,"两边的空格不能省略,否则会出错

结合脚本的内容可知这行命令的作用是将访问频率限制为每10秒最多3次,所以在终端中不断的运行此命令会发现当访问频率在10秒内小于或等于3次时返回1,否则返回0。

猜你喜欢

转载自blog.csdn.net/qq_41358574/article/details/120888216