java基于map的一个简单限流-代码

版权声明:本文为wcuu原创文章。 https://blog.csdn.net/wcuuchina/article/details/86572311

项目请求接口的简单限流实现

定义一个限流类:

/**
 * Per-key counter for a fixed-window rate limiter.
 *
 * <p>{@code time} is the last epoch second (inclusive) that still belongs to the
 * current window, and {@code invokeNum} is the number of calls counted in it.
 * Nothing in this class is synchronized; callers sharing an instance between
 * threads must serialize access themselves.
 *
 * @author wangwei
 * @version v1.0.0
 */
public class CacheValidate {

    /** Last epoch second (inclusive) of the current window. */
    private long time;
    /** Calls counted in the current window, including rejected ones. */
    private int invokeNum;

    public long getTime() {
        return time;
    }

    public void setTime(long time) {
        this.time = time;
    }

    public int getInvokeNum() {
        return invokeNum;
    }

    public void setInvokeNum(int invokeNum) {
        this.invokeNum = invokeNum;
    }

    /**
     * Counts one call and reports whether it is allowed.
     *
     * <p>Kept for existing callers. When the window has expired, the new window
     * ends in the current second because this overload has no duration to apply.
     * Prefer {@link #isValidate(int, int)} so the window keeps its length.
     *
     * @param limit maximum number of calls allowed in one window
     * @return {@code true} if the call is allowed, {@code false} if it exceeds {@code limit}
     */
    public boolean isValidate(int limit) {
        return isValidate(limit, 0);
    }

    /**
     * Counts one call and reports whether it is allowed.
     *
     * <p>If the current second is past {@code time}, a new window is opened that
     * ends at {@code now + sec}, the counter restarts at 1 and the call is allowed
     * regardless of {@code limit}. Otherwise the counter is incremented and compared
     * with {@code limit}.
     *
     * @param limit maximum number of calls allowed in one window
     * @param sec   number of seconds added to the current second to get the end of a
     *              newly opened window
     * @return {@code true} if the call is allowed, {@code false} if it exceeds {@code limit}
     */
    public boolean isValidate(int limit, int sec) {
        // Read the clock once so the expiry check and the new window agree.
        long nowSec = System.currentTimeMillis() / 1000;
        if (nowSec > time) {
            this.invokeNum = 1;
            this.time = nowSec + sec;
            return true;
        }
        // Saturate instead of wrapping: a negative counter would pass the check below.
        if (invokeNum < Integer.MAX_VALUE) {
            invokeNum++;
        }
        return invokeNum <= limit;
    }
}

一个限流的实现

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * In-memory fixed-window rate limiter keyed by an arbitrary string.
 *
 * <p>All instances share one static map, so creating a new {@code FlowLimit} per
 * call still observes the same counters. Every access to the map and to the
 * entries stored in it happens while holding the map's monitor.
 *
 * <p>NOTE(review): entries are never removed, so the map grows with the number
 * of distinct keys ever seen.
 *
 * @author wangwei
 * @version v1.0.0
 */
@Component
public class FlowLimit {

    /** Shared counters, guarded by synchronizing on the map itself. */
    private static final Map<String, CacheValidate> cache = new HashMap<String, CacheValidate>();

    /**
     * Counts one call for {@code apiName} and reports whether it is allowed.
     *
     * <p>The first call for a key, and the first call after a window has expired,
     * open a window ending at {@code now + sec} (epoch seconds, inclusive) and are
     * always allowed. Later calls in the same window are allowed while the running
     * count does not exceed {@code limit}.
     *
     * @param apiName key identifying what is being limited; {@code null} is rejected
     * @param sec     seconds added to the current second to get the end of the window
     * @param limit   maximum number of calls allowed in one window
     * @return {@code true} if the call may proceed, {@code false} otherwise
     */
    public boolean invoke(String apiName, int sec, int limit) {
        if (apiName == null) {
            return false;
        }
        synchronized (cache) {
            long nowSec = System.currentTimeMillis() / 1000;
            CacheValidate entry = cache.get(apiName);
            if (entry == null) {
                entry = new CacheValidate();
                cache.put(apiName, entry);
            } else if (nowSec <= entry.getTime()) {
                // Still inside the window: count the call, saturating so the
                // counter can never wrap negative and slip under the limit.
                int calls = entry.getInvokeNum();
                if (calls < Integer.MAX_VALUE) {
                    calls++;
                }
                entry.setInvokeNum(calls);
                return calls <= limit;
            }
            // New key or expired window: open a window of the requested length.
            entry.setTime(nowSec + sec);
            entry.setInvokeNum(1);
            return true;
        }
    }

    /**
     * Manual demo: submits ten tasks, 700 ms apart, that all hit the key
     * {@code "aaa"} with a limit of one call per window and print each result.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ExecutorService service = Executors.newFixedThreadPool(10);
        for (int i = 0; i < 10; i++) {
            try {
                Thread.sleep(700);
            } catch (InterruptedException e) {
                // Keep the interrupt visible and stop submitting further work.
                Thread.currentThread().interrupt();
                break;
            }
            service.submit(getTask());
        }
        service.shutdown();
    }

    /**
     * Builds the demo task used by {@link #main(String[])}: ten calls, 100 ms
     * apart, each printing the limiter's decision to standard error.
     *
     * @return a runnable performing the calls
     */
    public static Runnable getTask() {
        return new Runnable() {
            @Override
            public void run() {
                for (int i = 0; i < 10; i++) {
                    try {
                        Thread.sleep(100);
                    } catch (InterruptedException e) {
                        // Restore the flag and end the task instead of spinning on.
                        Thread.currentThread().interrupt();
                        return;
                    }
                    FlowLimit fLimit = new FlowLimit();
                    System.err.println(fLimit.invoke("aaa", 1, 1));
                }
            }
        };
    }

}

错误码定义类

/**
 * Error codes returned to API clients, each pairing a numeric value with a
 * user-facing message.
 */
public enum ErrorCode {

    SYSTEM_ERROR(500, "系统错误"),
    PARAMETER_CHECK_ERROR(400, "参数校验错误"),
    AUTH_VALID_ERROR(701, "用户权限不足"),
    UNLOGIN_ERROR(401, "用户未登录或登录状态超时失效"),

    CODE_430(430, "数据被篡改"),
    CODE_431(431, "秘钥不正确"),
    CODE_432(432, "请求太频繁,限流,请稍后再试"),

    CODE_450(450, "账户或者密码不正确"),
    CODE_451(451, "身份证号码验证失败"),

    CODE_6000(6000, "数据繁忙,请再试一次吧"),
    CODE_6001(6001, "手机号码已经注册,如果您忘记密码,请找回密码"),

    CODE_6010(6010, "银行卡已被绑定过,不可以再次绑定"),

    CODE_6800(6800, "数据处理失败"),

    CODE_9999(9999, "未知区域"),
    ;

    /** Numeric code; stored as a primitive so lookups compare by value. */
    private final int value;
    /** User-facing message for this code. */
    private final String message;

    ErrorCode(int value, String message) {
        this.value = value;
        this.message = message;
    }

    /**
     * @return the numeric code
     */
    public int getValue() {
        return value;
    }

    /**
     * @return the user-facing message
     */
    public String getMessage() {
        return message;
    }

    /**
     * @return the numeric code rendered as a decimal string
     */
    @Override
    public String toString() {
        return Integer.toString(value);
    }

    /**
     * @return the numeric code rendered as a decimal string
     */
    public String getCode() {
        return Integer.toString(value);
    }

    /**
     * Looks up the constant with the given numeric code.
     *
     * @param value numeric code to look up; may be {@code null}
     * @return the matching constant, or {@code null} when {@code value} is
     *         {@code null} or no constant carries that code
     */
    public static ErrorCode getByCode(Integer value) {
        if (value == null) {
            // Unboxing null in the comparison below would throw.
            return null;
        }
        final int wanted = value;
        for (ErrorCode candidate : values()) {
            if (candidate.value == wanted) {
                return candidate;
            }
        }
        return null;
    }

}
扫描二维码关注公众号,回复: 5712088 查看本文章

自定义一个限流的注解

import java.lang.annotation.*;

/**
 * Declares a request rate limit: at most {@link #limit()} calls within a
 * period of {@link #sec()}.
 *
 * @author wangwei
 * @version v1.0.0
 * @date 2019-01-19
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.TYPE, ElementType.METHOD})
public @interface AccessLimit {

    /**
     * Maximum number of calls allowed within the {@link #sec()} period.
     */
    int limit() default 1;

    /**
     * Length of the period the limit applies to.
     */
    int sec() default 5;
}

 拦截器类

import com.alibaba.fastjson.JSON;
import com.test.product_service.controller.base.BaseController;
import com.test.product_service.limiting.AccessLimit;
import com.test.product_service.utils.ErrorCode;
import com.test.product_service.utils.Resp;
import com.test.product_service.utils.redis.FlowLimit;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import javax.annotation.Resource;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;

/**
 * Interceptor that rate-limits handler methods annotated with {@link AccessLimit}.
 *
 * <p>The limiter key is the client address followed by the request URI. A rejected
 * request receives the JSON form of {@code Resp.fail(ErrorCode.CODE_432)} and the
 * handler is not invoked.
 *
 * @author wangwei
 * @version v1.0.0
 * @date 2019-01-19
 */
public class AccessLimitInterceptor implements HandlerInterceptor {

    /** Request headers consulted, in this order, before falling back to the remote address. */
    private static final String[] CLIENT_IP_HEADERS = {
            "x-forwarded-for",
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_CLIENT_IP",
            "HTTP_X_FORWARDED_FOR"
    };

    /**
     * Applies the limit declared on the target handler method, if any.
     *
     * @param request  current request; supplies the client address and URI for the key
     * @param response current response; written to only when the request is rejected
     * @param handler  chosen handler; anything that is not a {@code HandlerMethod} is let through
     * @return {@code true} to continue processing, {@code false} when the request was rejected
     * @throws Exception if writing the rejection response fails
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (!(handler instanceof org.springframework.web.method.HandlerMethod)) {
            return true;
        }
        Method method = ((org.springframework.web.method.HandlerMethod) handler).getMethod();
        AccessLimit accessLimit = method.getAnnotation(AccessLimit.class);
        if (accessLimit == null) {
            // Not annotated: nothing to limit.
            return true;
        }

        String key = getIpAddress(request) + request.getRequestURI();
        if (new FlowLimit().invoke(key, accessLimit.sec(), accessLimit.limit())) {
            return true;
        }
        output(response, JSON.toJSONString(Resp.fail(ErrorCode.CODE_432)));
        return false;
    }

    /**
     * Writes {@code msg} to the response as UTF-8 JSON and closes the output stream.
     *
     * @param response response to write to
     * @param msg      body to send
     * @throws IOException if the stream cannot be obtained, written, flushed or closed
     */
    public void output(HttpServletResponse response, String msg) throws IOException {
        response.setContentType("application/json;charset=UTF-8");
        // try-with-resources: the stream is only closed if it was actually obtained.
        try (ServletOutputStream outputStream = response.getOutputStream()) {
            outputStream.write(msg.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            outputStream.flush();
        }
    }

    /** No post-processing is needed. */
    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {

    }

    /** No completion work is needed. */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {

    }

    /**
     * Resolves the client address for a request.
     *
     * <p>The headers in {@code CLIENT_IP_HEADERS} are tried in order; the first one
     * holding a usable value wins. When a header carries a comma-separated list,
     * only its first entry is used so the key does not vary with the proxy chain.
     * If no header is usable, {@code request.getRemoteAddr()} is returned.
     *
     * <p>NOTE(review): these headers are supplied by the client or by intermediaries,
     * so a caller can vary them to obtain a fresh limiter key — confirm they are set
     * by a trusted proxy in deployment.
     *
     * @param request current request
     * @return the resolved client address
     */
    public static String getIpAddress(HttpServletRequest request) {
        for (String header : CLIENT_IP_HEADERS) {
            String raw = request.getHeader(header);
            if (raw == null || raw.length() == 0 || "unknown".equalsIgnoreCase(raw)) {
                continue;
            }
            int comma = raw.indexOf(',');
            String first = (comma >= 0 ? raw.substring(0, comma) : raw).trim();
            if (first.length() != 0 && !"unknown".equalsIgnoreCase(first)) {
                return first;
            }
        }
        return request.getRemoteAddr();
    }

}

多拦截器配置

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

/**
 * Web MVC configuration that installs {@link AccessLimitInterceptor} for every path.
 *
 * <p>NOTE(review): {@code @EnableWebMvc} is kept as in the original; confirm it is
 * intended, since it takes over MVC configuration for the application.
 *
 * @author wangwei
 * @version v1.0.0
 */
@EnableWebMvc
@Configuration
public class MywebConfig implements WebMvcConfigurer {

    /**
     * Registers the rate-limit interceptor for all request paths.
     *
     * @param registry registry the interceptor is added to
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // Register the bean declared below rather than a second, unmanaged instance.
        registry.addInterceptor(localInterceptor()).addPathPatterns("/**");
    }

    /**
     * Exposes the rate-limit interceptor as a bean.
     *
     * @return the interceptor instance
     */
    @Bean
    AccessLimitInterceptor localInterceptor() {
        return new AccessLimitInterceptor();
    }
}

过滤器配置


import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Registers the application's servlet filters.
 *
 * <p>Each filter is mapped to {@code /*} and given an explicit order value:
 * {@code LogFilter} (2), {@code SignFilter} (3), {@code LoginValidateFilter} (4).
 *
 * @author wangwei
 * @version v1.0.0
 * @date 2019-10-12
 */
@Configuration
public class FilterConfig {

    /**
     * Registration for {@code LogFilter}, order 2.
     *
     * @return the configured registration
     */
    @Bean
    public FilterRegistrationBean buildLogFilter() {
        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setFilter(new LogFilter());
        return mapToAllUrls(registration, "LogFilter", 2);
    }

    /**
     * Registration for the signature filter {@code SignFilter}, order 3.
     *
     * @author wangwei
     * @date 2019/1/21
     * @return the configured registration
     */
    @Bean
    public FilterRegistrationBean buildCFilter() {
        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setFilter(new SignFilter());
        return mapToAllUrls(registration, "SignFilter", 3);
    }

    /**
     * Registration for {@code LoginValidateFilter}, order 4.
     *
     * @return the configured registration
     */
    @Bean
    public FilterRegistrationBean buildDFilter() {
        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setFilter(new LoginValidateFilter());
        return mapToAllUrls(registration, "LoginValidateFilter", 4);
    }

    /** Applies the settings shared by every registration: order, name and the {@code /*} mapping. */
    private static FilterRegistrationBean mapToAllUrls(FilterRegistrationBean registration, String name, int order) {
        registration.setOrder(order);
        registration.setName(name);
        registration.addUrlPatterns("/*");
        return registration;
    }
}

控制器配置自定义拦截注解

/**
 * Example endpoint carrying the rate-limit annotation with {@code limit = 1}
 * and {@code sec = 1}.
 *
 * <p>The body ignores all of its arguments and always returns
 * {@code Resp.success()}.
 *
 * <p>NOTE(review): {@code consumes} constrains the request's content type on a
 * GET mapping; confirm {@code produces} was not intended.
 *
 * @param time        not bound by an explicit annotation; unused in the body
 * @param apptype     required request parameter {@code apptype}; unused in the body
 * @param ios_version required request parameter {@code ios_version}; unused in the body
 * @param request     current servlet request; unused in the body
 * @return {@code Resp.success()}
 */
@AccessLimit(limit = 1,sec = 1)
    @GetMapping(value = "list", consumes = MediaType.APPLICATION_JSON_UTF8_VALUE)
    public Resp list(String time,
                     @RequestParam(name = "apptype", required = true)String apptype,
                     @RequestParam(name = "ios_version", required = true)String ios_version,
                     HttpServletRequest request){
        return Resp.success();
    }

当前配置为每 1 秒最多请求一次;1 秒内超过一次的请求会返回限流提示。

超过1s后,可正常访问。

按接口详细拦截

 完毕。

猜你喜欢

转载自blog.csdn.net/wcuuchina/article/details/86572311