Interceptors and filters in Spring Boot: reading request parameters more than once

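Hi everyone, I'm 烤鸭 (Roast Duck):
This post covers interceptors (Interceptor) and filters (Filter) in Spring Boot.
First, a word on filters versus interceptors.
Differences:
1. For a servlet request the order is Filter -> interceptor. Filters run before the request reaches Spring MVC's DispatcherServlet; interceptors run inside it.
2. A Filter can act on everything that reaches the servlet container, including endpoints and static resources. An interceptor only sees requests that the DispatcherServlet dispatches to a handler.
3. A Filter screens requests or resources and lets the suitable ones through. An interceptor stops only the requests that fail its checks.
4. A Filter is callback-based: the doFilter method we implement from the Filter interface is the callback that the container invokes through the FilterChain. An interceptor is invoked by Spring MVC itself; it is often described as reflection-based, because the handler method it wraps is called through Java reflection. This is usually cited as the most fundamental difference between the two.
5. A Filter depends on the servlet container and only runs inside it; without a servlet container nothing would call doFilter. An interceptor is a Spring-managed component and is not registered with the servlet container.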
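A plain-Java model of points 1 and 3, with no servlet or Spring dependency. `DemoFilter`, `DemoInterceptor` and `RequestOrderDemo` are hypothetical stand-ins, not the real `javax.servlet.Filter` or Spring's `HandlerInterceptor`, and the model is simplified (real Spring MVC also calls `postHandle`, and calls `afterCompletion` even when the handler throws). It only shows that the filter wraps the whole dispatch, while `preHandle` runs inside it, right before the handler, and can stop the request by returning false.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins, not the real servlet or Spring MVC types
interface DemoFilter {
    void doFilter(List<String> log, Runnable chain);
}

interface DemoInterceptor {
    boolean preHandle(List<String> log);
    void afterCompletion(List<String> log);
}

class RequestOrderDemo {
    // Models one request: filter -> "dispatcher" -> interceptor -> handler
    static List<String> handle(DemoFilter filter, DemoInterceptor interceptor) {
        List<String> log = new ArrayList<>();
        filter.doFilter(log, () -> {
            // The handler only runs if preHandle returns true
            if (interceptor.preHandle(log)) {
                log.add("handler");
                interceptor.afterCompletion(log);
            }
        });
        return log;
    }

    // Builds a logging filter and an interceptor that allows or rejects the request
    static List<String> run(boolean allow) {
        DemoFilter filter = (log, chain) -> {
            log.add("filter:before");
            chain.run();
            log.add("filter:after");
        };
        DemoInterceptor interceptor = new DemoInterceptor() {
            public boolean preHandle(List<String> log) {
                log.add("interceptor:preHandle");
                return allow;
            }

            public void afterCompletion(List<String> log) {
                log.add("interceptor:afterCompletion");
            }
        };
        return handle(filter, interceptor);
    }

    public static void main(String[] args) {
        System.out.println(run(true));
        System.out.println(run(false));
    }
}
```

With `run(false)` the filter still runs before and after, but the handler never executes, which is exactly how the signature interceptor later in this post rejects a request.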

The rest of the post is code-heavy, so it is not the lightest read.

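To summarize up front, this is what I use the filter and the interceptor for:

Filter: make the request parameters, which arrive as a stream, readable more than once.
Interceptor: verify the signature of all callback endpoints in one place.
Every callback endpoint has to verify a signature, and the signing rule is the same for all of them, so it makes sense to handle this in an interceptor.

If the signature or the client IP address does not pass, the request is answered right away. The individual endpoints only need to focus on business logic and no longer check signatures themselves.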
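The early-return behaviour described above can be sketched in plain Java. This is a simplified model of what the interceptor's `preHandle` does, not the author's code: `CallbackGate`, the rejection texts, the whitelist contents and the `signatureValid` flag are hypothetical, and returning `null` stands for "let the request continue to the controller".

```java
import java.util.Set;

class CallbackGate {
    // Returns the reason for rejection, or null if the request may proceed
    static String check(String method, String clientIp, Set<String> ipWhitelist, boolean signatureValid) {
        if (!"POST".equals(method)) {
            return "method not allowed";   // stop at the first failed check
        }
        if (!ipWhitelist.contains(clientIp)) {
            return "ip not in whitelist";
        }
        if (!signatureValid) {
            return "bad signature";
        }
        return null;                        // only requests that passed reach the controller
    }

    public static void main(String[] args) {
        Set<String> whitelist = Set.of("10.0.0.1");
        System.out.println(check("POST", "10.0.0.1", whitelist, true));
        System.out.println(check("GET", "10.0.0.1", whitelist, true));
    }
}
```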

Here is how they are used in Spring Boot:

1. filter:

InterfaceFilter.java

package com.test.test.filter;

import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

@WebFilter(urlPatterns = "/*", filterName = "InterfaceFilter")
public class InterfaceFilter implements Filter {
	private static final Logger log = LoggerFactory.getLogger(InterfaceFilter.class);

	@Override
	public void init(FilterConfig filterConfig) throws ServletException {
	}

	@Override
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
		HttpServletRequest req = (HttpServletRequest) request;
		try {
			if ("POST".equalsIgnoreCase(req.getMethod())) {
				// Read the request body once; the input stream cannot be read a second time
				byte[] bytes = IOUtils.toByteArray(request.getInputStream());
				// getCharacterEncoding() may return null, so fall back to UTF-8
				String encoding = req.getCharacterEncoding() != null ? req.getCharacterEncoding() : StandardCharsets.UTF_8.name();
				String params = new String(bytes, encoding);
				// Cache the body for the current thread so later code (e.g. the interceptor) can read it
				ThreadCache.setPostRequestParams(params);
				log.info("filter-post request params: [params={}]", params);
			} else {
				log.info("non-POST request");
			}
			chain.doFilter(request, response);
		} finally {
			// Containers reuse worker threads, so always clear the cached body after the request
			ThreadCache.removePostRequestParams();
		}
	}

	@Override
	public void destroy() {
	}
}

ThreadCache.java:

package com.test.test.filter;
public class ThreadCache {
	// The ThreadLocal only holds a plain String here; a custom object could hold more complex parameters
    private static final ThreadLocal<String> threadLocal = new ThreadLocal<String>();

    public static String getPostRequestParams(){
        return threadLocal.get();
    }

    public static void setPostRequestParams(String postRequestParams){
        threadLocal.set(postRequestParams);
    }

    public static void removePostRequestParams(){
        threadLocal.remove();
    }
}

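A note on the @WebFilter annotation.

urlPatterns specifies which paths the filter applies to, just like the url-pattern setting in web.xml. Spring Boot only registers @WebFilter classes when the application class is annotated with @ServletComponentScan; alternatively, the filter can be registered through a FilterRegistrationBean.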

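In short, the purpose of this filter is to let the request parameters, which arrive as a stream, be retrieved more than once.
Normally request.getInputStream() can be read only once, but the filter stores the parameters in a ThreadLocal,
so they remain available anywhere during the current request. Note that the filter passes the original request down the chain, so the body stream itself is already consumed; downstream code, including @RequestBody controller parameters, has to read the parameters from ThreadCache instead.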
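A JDK-only sketch of the idea: a stream can be drained only once, so the body is read into bytes a single time and kept in a `ThreadLocal` for the rest of the request. `ByteArrayInputStream` stands in for `request.getInputStream()`, and `BodyCacheDemo` is a hypothetical class mirroring `ThreadCache`; the `finally` block clears the value because servlet containers reuse worker threads.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

class BodyCacheDemo {
    // Hypothetical equivalent of ThreadCache: one cached body per thread
    private static final ThreadLocal<String> BODY = new ThreadLocal<>();

    static String cached() {
        return BODY.get();
    }

    // Plays the role of the filter: drain the stream once, cache the text, run the chain
    static void filter(InputStream in, Runnable chain) {
        try {
            BODY.set(new String(in.readAllBytes(), StandardCharsets.UTF_8));
            chain.run();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } finally {
            BODY.remove(); // worker threads are reused, so never leave the value behind
        }
    }

    // Reads the cached body twice during one "request" and records what was seen
    static List<String> readTwice(String body) {
        ByteArrayInputStream in = new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8));
        List<String> seen = new ArrayList<>();
        filter(in, () -> {
            seen.add(cached());                  // e.g. the interceptor
            seen.add(cached());                  // e.g. the controller: same text again
            seen.add(String.valueOf(in.read())); // the raw stream itself is already drained
        });
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(readTwice("{\"uuid\":\"PC0000000056\"}"));
        System.out.println(cached()); // null: cleared when the request finished
    }
}
```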
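Also worth mentioning:
there are two ways to read request parameters more than once.
1. Create a ThreadLocal holder class and put the request, or its parameters, into it.
2. Create a wrapper class (decorator pattern, typically extending HttpServletRequestWrapper) around the request: it reads getInputStream() once, caches the bytes, and returns a new stream over them on every later call.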
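Option 2 boils down to caching the body bytes and handing out a fresh stream on every call. The sketch below shows only that core idea with JDK types; in a real application the class would extend `javax.servlet.http.HttpServletRequestWrapper` and return a `ServletInputStream`, which is left out here so the code runs without the servlet API on the classpath. `CachedBody` is a hypothetical name.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

class CachedBody {
    private final byte[] body;

    // Drain the original stream exactly once
    CachedBody(InputStream original) {
        try {
            this.body = original.readAllBytes();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Every call returns a new stream positioned at the start of the cached bytes
    InputStream getInputStream() {
        return new ByteArrayInputStream(body);
    }

    // Convenience: read one fresh stream completely as UTF-8 text
    String readAll() {
        try (InputStream in = getInputStream()) {
            return new String(in.readAllBytes(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        CachedBody wrapped = new CachedBody(
                new ByteArrayInputStream("sign=ABC&uuid=1".getBytes(StandardCharsets.UTF_8)));
        // Two independent reads both see the full body
        System.out.println(wrapped.readAll());
        System.out.println(wrapped.readAll());
    }
}
```

Compared with the ThreadLocal approach, the wrapper keeps the body attached to the request object itself, so code that calls getInputStream() (such as @RequestBody binding) keeps working, provided the wrapped request is what gets passed to chain.doFilter.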

A recommended post on reading request parameters more than once:

https://www.cnblogs.com/endstart/p/6196807.html

2. Interceptor:


PlatformInterceptor.java:

package com.test.test.interceptor;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import com.alibaba.fastjson.JSONObject;
import com.test.test.config.ConfigProperties;
import com.test.test.constants.IMsgEnum;
import com.test.test.filter.ThreadCache;
import com.test.test.resp.BaseResp;
import com.test.test.security.SignCheck;
import com.test.test.utils.LogUtils;
import com.test.test.utils.NetworkUtil;
import com.test.test.utils.ReflectUtil;

/**
 * ClassName: PlatformInterceptor date: 2015-12-30 14:13:24 Description: interceptor
 * 
 * @author xiaozhan
 * @version
 * @since JDK 1.8
 */
@Component
public class PlatformInterceptor implements HandlerInterceptor {

	private static final Log logger = LogFactory.getLog(PlatformInterceptor.class);

	@Autowired
	private SignCheck signCheck;

	@Autowired
	private ConfigProperties configProperties;

	@Override
	public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
			throws Exception {
		logger.info(LogUtils.getRequestLog(request));
		String servletPath = request.getServletPath();
		boolean isDeprecated = false;
		BaseResp baseResp = new BaseResp();

		// Check whether the handler method is annotated with @Deprecated
		if (handler instanceof HandlerMethod) {
			HandlerMethod handlerMethod = (HandlerMethod) handler;
			isDeprecated = handlerMethod.getMethodAnnotation(Deprecated.class) != null;
		}

		// Only POST requests are accepted
		if (!RequestMethod.POST.name().equals(request.getMethod())) {
			baseResp.getMsg().setRstcode(IMsgEnum.PARAM_REQUEST_METHOD_FALSE.getMsgCode());
			baseResp.getMsg().setRsttext(IMsgEnum.PARAM_REQUEST_METHOD_FALSE.getMsgText());
			logger.info("----- " + IMsgEnum.PARAM_REQUEST_METHOD_FALSE.getMsgText() + " -----");
			writeJson(response, baseResp);
			return false;
		}

		String clientIp = NetworkUtil.getIpAddress(request);
		logger.info("------ client Ip is ---> " + clientIp);
		// Reject clients that are not on the IP whitelist
		if (!signCheck.checkIpAddress(clientIp)) {
			baseResp.getMsg().setRstcode(IMsgEnum.PARAM_IP_ADDRESS_FALSE.getMsgCode());
			baseResp.getMsg().setRsttext(IMsgEnum.PARAM_IP_ADDRESS_FALSE.getMsgText());
			logger.info("----- " + IMsgEnum.PARAM_IP_ADDRESS_FALSE.getMsgText() + " -----");
			writeJson(response, baseResp);
			return false;
		}
		// Read the body cached by InterfaceFilter
		String params = ThreadCache.getPostRequestParams();
		// commons-logging has no {} placeholders, so concatenate the value
		logger.info("interceptor-post request params: [params=" + params + "]");
		Map<String, Object> map = ReflectUtil.getDecodeParamMap(params);
		// The sign field itself is not part of the signed content
		String sign = (String) map.remove("sign");
		// Verify the signature
		if (!SignCheck.checkSign(map, sign, configProperties.getPrivateKey())) {
			baseResp.getMsg().setRstcode(IMsgEnum.PARAM_SIGN_FALSE.getMsgCode());
			baseResp.getMsg().setRsttext(IMsgEnum.PARAM_SIGN_FALSE.getMsgText());
			logger.info("----- " + IMsgEnum.PARAM_SIGN_FALSE.getMsgText() + " -----");
			writeJson(response, baseResp);
			return false;
		}
		try {
			if (isDeprecated) {
				logger.error(LogUtils.getCommLog(String.format("This endpoint has been discontinued: %s", servletPath)));
			}
		} catch (Exception e) {
			baseResp.getMsg().setRstcode(IMsgEnum.PARAMETER_INVALID.getMsgCode());
			baseResp.getMsg().setRsttext(IMsgEnum.PARAMETER_INVALID.getMsgText());
			logger.info("----- " + IMsgEnum.PARAMETER_INVALID.getMsgText() + " -----");
			writeJson(response, baseResp);
			return false;
		}
		return true;
	}

	// Write an error response as UTF-8 JSON; the output stream is only obtained when rejecting
	private void writeJson(HttpServletResponse response, BaseResp baseResp) throws IOException {
		response.setContentType("application/json;charset=UTF-8");
		ServletOutputStream out = response.getOutputStream();
		out.write(JSONObject.toJSONString(baseResp).getBytes(StandardCharsets.UTF_8));
	}

	@Override
	public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
			ModelAndView modelAndView) throws Exception {

	}

	@Override
	public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
			throws Exception {
	}

}
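The interceptor can be restricted to specific requests; here I did not add any restriction, so every request goes through it. (The interceptor also has to be registered in the addInterceptors method of a WebMvcConfigurer; that configuration is not shown in this post. When registering, addPathPatterns and excludePathPatterns can limit which paths it applies to.)

To intercept only specific requests inside the interceptor itself, check whether request.getServletPath() (or request.getRequestURI()) contains the path in question. request.getContextPath() only returns the application's context path, not the path of the endpoint being called.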
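A JDK-only sketch of the in-interceptor check suggested above, using the path inside the application (what `getServletPath()` returns) rather than the context path. `CallbackPathMatcher` and the `/callback/` prefix are hypothetical; in `preHandle` one would return `true` early when `shouldIntercept(request.getServletPath())` is false.

```java
import java.util.List;

class CallbackPathMatcher {
    private final List<String> prefixes;

    CallbackPathMatcher(List<String> prefixes) {
        this.prefixes = prefixes;
    }

    // servletPath is the path inside the application, e.g. "/callback/pay"
    boolean shouldIntercept(String servletPath) {
        if (servletPath == null) {
            return false;
        }
        for (String prefix : prefixes) {
            if (servletPath.startsWith(prefix)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        CallbackPathMatcher matcher = new CallbackPathMatcher(List.of("/callback/"));
        System.out.println(matcher.shouldIntercept("/callback/pay"));
        System.out.println(matcher.shouldIntercept("/user/info"));
    }
}
```

Prefix matching is the simplest choice; registering the interceptor with addPathPatterns("/callback/**") keeps the rule out of the interceptor code entirely.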

As described at the start, I use the interceptor to verify the signature of all callback endpoints in one place.
Here are the utility classes it uses for the IP and signature checks:

SignCheck.java:

package com.test.test.security;

import java.util.Arrays;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import com.test.test.config.ConfigProperties;
import com.test.test.utils.Signature;

/** 
* @author gmwang E-mail: 
* @version Created: 2018-03-01 10:35:47 
* Description: verifies the signature and the client IP
*/
@Service(value = "signCheck")
public class SignCheck {
	private static final Log logger = LogFactory.getLog(SignCheck.class);

	@Autowired
	private ConfigProperties configProperties;

	/**
	 * Verify the signature
	 * @time 2018-03-01 10:38:09
	 */
	public static boolean checkSign(Map<String, Object> params, String sign, String privateKey) {
		if (StringUtils.isBlank(sign)) {
			logger.info("*********************sign is null*********************");
			return false;
		}
		// Recompute the signature from the remaining parameters and compare
		String signAfter = Signature.getSign(params, privateKey);
		logger.debug("sign:" + sign + ", signAfter:" + signAfter);
		if (!sign.equals(signAfter)) {
			logger.info("*********************sign is not equal signAfter*********************");
			return false;
		}
		return true;
	}

	/**
	 * Check the client IP against the comma-separated whitelist
	 * @time 2018-03-01 10:38:09
	 */
	public boolean checkIpAddress(String ip) {
		String ipWhite = configProperties.getRequestUrl();
		logger.debug("ip whitelist: " + ipWhite);
		// Trim each entry so that "a, b" in the config also matches
		boolean allowed = Arrays.stream(ipWhite.split(","))
				.map(String::trim)
				.anyMatch(s -> s.equals(ip));
		if (!allowed) {
			logger.info("*********************ip is not in ipWhiteList*********************");
			return false;
		}
		return true;
	}
}
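Utility class for obtaining the client IP address. Keep in mind that headers such as X-Forwarded-For are set by the client or by proxies and can be forged, so an IP whitelist based on them is only reliable when the application sits behind a proxy you control that overwrites these headers: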

NetworkUtil.java:

package com.test.test.utils;

import javax.servlet.http.HttpServletRequest;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Common utilities for obtaining client information
 */
public final class NetworkUtil {
    private static final Logger logger = LoggerFactory.getLogger(NetworkUtil.class);

    // Proxy headers checked in order when X-Forwarded-For is missing
    private static final String[] FALLBACK_HEADERS = {
            "Proxy-Client-IP", "WL-Proxy-Client-IP", "HTTP_CLIENT_IP", "HTTP_X_FORWARDED_FOR"};

    private NetworkUtil() {
    }

    /**
     * Get the client IP address; if the request came through proxies,
     * look through the proxy headers for the real client IP.
     *
     * @param request the current request
     * @return the client IP address
     */
    public static String getIpAddress(HttpServletRequest request) {
        String ip = request.getHeader("X-Forwarded-For");
        logger.info("getIpAddress - X-Forwarded-For - ip={}", ip);

        if (isUnknown(ip)) {
            // Fall back through the other proxy headers, then the socket address
            for (String header : FALLBACK_HEADERS) {
                ip = request.getHeader(header);
                logger.info("getIpAddress - {} - ip={}", header, ip);
                if (!isUnknown(ip)) {
                    return ip;
                }
            }
            ip = request.getRemoteAddr();
            logger.info("getIpAddress - getRemoteAddr - ip={}", ip);
        } else if (ip.contains(",")) {
            // X-Forwarded-For may hold "client, proxy1, proxy2": take the first known entry
            for (String strIp : ip.split(",")) {
                String candidate = strIp.trim();
                if (!isUnknown(candidate)) {
                    ip = candidate;
                    break;
                }
            }
        }
        return ip;
    }

    private static boolean isUnknown(String ip) {
        return ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip);
    }
}
Signature and encryption/decryption utility classes:

Signature.java:

package com.test.test.utils;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.alibaba.fastjson.JSONObject;


/**
 * User: 
 * Date: 2015/8/26
 * Time: 15:23
 */
public class Signature {
    /**
     * Signature algorithm
     * @param o the data object whose fields are signed
     * @return the signature
     * @throws IllegalAccessException
     * Rule: sort the selected parameters by the ASCII code of the first character of the key in ascending
     * order (alphabetical order); if the first characters are equal, compare the second character, and so on.
     * Join them into the form key=value&...skey (i.e. append the secret key) and MD5-hash the string (required).
     * Note: the code sorts whole "key=value&" entries with String.CASE_INSENSITIVE_ORDER,
     * so the calling side must use exactly the same ordering.
     */
    public static String getSign(Object o) throws IllegalAccessException {
        List<String> list = new ArrayList<>();
        for (Field f : o.getClass().getDeclaredFields()) {
            f.setAccessible(true);
            Object value = f.get(o);
            // Skip null and empty values (compare with equals, not !=)
            if (value != null && !"".equals(value)) {
                list.add(f.getName() + "=" + value + "&");
            }
        }
        list.sort(String.CASE_INSENSITIVE_ORDER);
        String result = String.join("", list);
        return MD5Util.MD5Encode(result).toUpperCase();
    }

    public static String getSign(Map<String, Object> map, String privateKey) {
        List<String> list = new ArrayList<>();
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            Object value = entry.getValue();
            // Skip null and empty values (compare with equals, not !=)
            if (value != null && !"".equals(value)) {
                list.add(entry.getKey() + "=" + value + "&");
            }
        }
        list.sort(String.CASE_INSENSITIVE_ORDER);
        // The private key is appended after the sorted entries; never log this string
        String result = String.join("", list) + privateKey;
        return MD5Util.MD5Encode(result).toUpperCase();
    }

    public static void main(String[] args) {
        Map<String, Object> map = new HashMap<>();
        map.put("uuid", "PC0000000056");
        String result = getSign(map, "aaaaa!aaa");
        System.out.println(result);
        map.put("sign", result);
        // Encrypt the JSON body with DES and POST it (HttpUtil is a project helper not shown in this post)
        DESTools desTools = new DESTools();
        String s = JSONObject.toJSONString(map);
        System.out.println(s);
        String param = desTools.getEncString(s);
        String str = HttpUtil.doPost("http://localhost:8111/test/test", param);
        System.out.println(str);
    }
}
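To make the signing rule concrete, here is a JDK-only sketch of what `getSign(Map, String)` computes. The author's `MD5Util.MD5Encode` is not shown, so this assumes it returns the hex MD5 digest of the UTF-8 bytes; `SignSketch` is a hypothetical name. Like the code above, it skips empty values, sorts whole `key=value&` entries case-insensitively, and appends the private key before hashing.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class SignSketch {
    // Builds the string that gets hashed: sorted "key=value&" entries plus the private key
    static String preImage(Map<String, Object> params, String privateKey) {
        List<String> entries = new ArrayList<>();
        for (Map.Entry<String, Object> e : params.entrySet()) {
            Object v = e.getValue();
            if (v != null && !"".equals(v)) {
                entries.add(e.getKey() + "=" + v + "&");
            }
        }
        entries.sort(String.CASE_INSENSITIVE_ORDER);
        return String.join("", entries) + privateKey;
    }

    // Upper-case hex MD5 of the pre-image (assumed equivalent to MD5Util.MD5Encode(...).toUpperCase())
    static String sign(Map<String, Object> params, String privateKey) {
        try {
            byte[] digest = MessageDigest.getInstance("MD5")
                    .digest(preImage(params, privateKey).getBytes(StandardCharsets.UTF_8));
            StringBuilder hex = new StringBuilder();
            for (byte b : digest) {
                hex.append(String.format("%02X", b));
            }
            return hex.toString();
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 not available", e);
        }
    }

    public static void main(String[] args) {
        Map<String, Object> params = new TreeMap<>();
        params.put("uuid", "PC0000000056");
        params.put("amount", "100");
        System.out.println(preImage(params, "aaaaa!aaa")); // amount=100&uuid=PC0000000056&aaaaa!aaa
        System.out.println(sign(params, "aaaaa!aaa"));
    }
}
```

On the receiving side the `sign` field has to be taken out of the map before the signature is recomputed, which is what PlatformInterceptor does before calling SignCheck.checkSign.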

DESTools.java:

package com.test.test.utils;

import java.nio.charset.StandardCharsets;
import java.security.Key;

import javax.crypto.Cipher;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.DESKeySpec;

import org.apache.commons.codec.binary.Base64;

/**
 * DES encryption/decryption with Base64-encoded ciphertext.
 * With the default SunJCE provider, Cipher.getInstance("DES") means DES/ECB/PKCS5Padding,
 * which is weak by current standards; it is kept for compatibility with the calling side.
 */
public class DESTools {

	public static DESTools instance;

	// synchronized so that concurrent first calls cannot create two instances
	public static synchronized DESTools getInstace() {
		if (instance == null) {
			instance = new DESTools();
		}
		return instance;
	}

	private Key key;

	/**
	 * Secret key (DESKeySpec only uses the first 8 bytes)
	 */
	private static final byte[] BOSS_SECRET_KEY = { 0x0b, 0x13, (byte) 0xe7,
			(byte) 0xb2, 0x51, 0x0d, 0x75, (byte) 0xc2, 0x4e, (byte) 0xdd,
			(byte) 0x4b, (byte) 0x51, 0x24, 0x36, (byte) 0xa8, (byte) 0x28,
			0x0b, 0x13, (byte) 0xe2, (byte) 0xb2, 0x31, 0x0d, 0x75, (byte) 0xc1 };

	public DESTools() {
		setKey(BOSS_SECRET_KEY);
	}

	/**
	 * Generate the KEY from the given bytes
	 */
	public void setKey(byte[] strKey) {
		try {
			DESKeySpec dks = new DESKeySpec(strKey);
			SecretKeyFactory keyFactory = SecretKeyFactory.getInstance("DES");
			this.key = keyFactory.generateSecret(dks);
		} catch (Exception e) {
			throw new RuntimeException("Error initializing DESTools key", e);
		}
	}

	/**
	 * Encrypt: String plaintext in, Base64 String ciphertext out
	 */
	public String getEncString(String strMing) {
		byte[] byteMi = getEncCode(strMing.getBytes(StandardCharsets.UTF_8));
		return new Base64().encodeAsString(byteMi);
	}

	/**
	 * Decrypt: Base64 String ciphertext in, String plaintext out
	 * @param strMi
	 * @return
	 */
	public String getDesString(String strMi) {
		byte[] byteMing = getDesCode(new Base64().decode(strMi));
		return new String(byteMing, StandardCharsets.UTF_8);
	}

	/**
	 * Encrypt: byte[] plaintext in, byte[] ciphertext out
	 * @param byteS
	 * @return
	 */
	private byte[] getEncCode(byte[] byteS) {
		try {
			Cipher cipher = Cipher.getInstance("DES");
			cipher.init(Cipher.ENCRYPT_MODE, key);
			return cipher.doFinal(byteS);
		} catch (Exception e) {
			throw new RuntimeException("DES encryption failed", e);
		}
	}

	/**
	 * Decrypt: byte[] ciphertext in, byte[] plaintext out
	 * @param byteD
	 * @return
	 */
	private byte[] getDesCode(byte[] byteD) {
		try {
			Cipher cipher = Cipher.getInstance("DES");
			cipher.init(Cipher.DECRYPT_MODE, key);
			return cipher.doFinal(byteD);
		} catch (Exception e) {
			throw new RuntimeException("DES decryption failed", e);
		}
	}

}
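A JDK-only round trip showing what `getEncString` and `getDesString` do: DES encryption followed by Base64, and the reverse. It uses `java.util.Base64` instead of commons-codec and a made-up 8-byte key, and `DesRoundTrip` is a hypothetical name, so this is an illustration rather than a drop-in replacement. Because ECB mode is deterministic, the same plaintext always yields the same ciphertext; for new code, AES-GCM would be the usual choice instead of DES.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.Base64;
import javax.crypto.Cipher;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.DESKeySpec;

class DesRoundTrip {
    // Made-up demo key; DESKeySpec uses only the first 8 bytes of its input
    private static final byte[] KEY_BYTES = {0x01, 0x23, 0x45, 0x67, (byte) 0x89, (byte) 0xab, (byte) 0xcd, (byte) 0xef};

    // Encrypts or decrypts with DES/ECB/PKCS5Padding, the transformation SunJCE uses for plain "DES"
    private static byte[] run(int mode, byte[] input) {
        try {
            Cipher cipher = Cipher.getInstance("DES/ECB/PKCS5Padding");
            cipher.init(mode, SecretKeyFactory.getInstance("DES").generateSecret(new DESKeySpec(KEY_BYTES)));
            return cipher.doFinal(input);
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("DES operation failed", e);
        }
    }

    // Plaintext -> DES -> Base64, like getEncString
    static String encrypt(String plain) {
        return Base64.getEncoder().encodeToString(run(Cipher.ENCRYPT_MODE, plain.getBytes(StandardCharsets.UTF_8)));
    }

    // Base64 -> DES -> plaintext, like getDesString
    static String decrypt(String cipherText) {
        return new String(run(Cipher.DECRYPT_MODE, Base64.getDecoder().decode(cipherText)), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        String json = "{\"uuid\":\"PC0000000056\"}";
        String enc = encrypt(json);
        System.out.println(enc);
        System.out.println(decrypt(enc).equals(json)); // true
    }
}
```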

Reading values from the yml file:
ConfigProperties.java:

package com.test.test.config;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;

// @Configuration is already meta-annotated with @Component, so a separate @Component is not needed
@Configuration
public class ConfigProperties {
	
	@Value("${test.test.privateKey}")
	private String privateKey;
	@Value("${test.test.requestUrl}")
	private String requestUrl;

	public String getPrivateKey() {
		return privateKey;
	}

	public void setPrivateKey(String privateKey) {
		this.privateKey = privateKey;
	}

	public String getRequestUrl() {
		return requestUrl;
	}

	public void setRequestUrl(String requestUrl) {
		this.requestUrl = requestUrl;
	}

}

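The yml file was shown as an image in the original post (not reproduced here); it defines test.test.privateKey and test.test.requestUrl, the latter being used by SignCheck as the comma-separated IP whitelist.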



Reposted from blog.csdn.net/angry_mills/article/details/79456137