SpringBoot实现XSS过滤
导入 hutool 工具库(具体用法可参考 hutool 官方文档):
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.1.4</version>
</dependency>
防XSS攻击过滤器
package com.xss.filter;
import cn.hutool.core.util.StrUtil;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Servlet filter that wraps incoming requests in {@link XssHttpServletRequestWrapper} so that
 * request parameters and JSON bodies are HTML-escaped before they reach the application.
 *
 * <p>Init parameters:
 * <ul>
 *   <li>{@code enabled} - "true" turns filtering on; anything else (or absent) leaves it off.</li>
 *   <li>{@code excludes} - comma-separated regular expressions; a request whose servlet path
 *       starts with a match of any of them is passed through unwrapped.</li>
 * </ul>
 */
public class XssFilter implements Filter {
    /**
     * Exclusion patterns, compiled once in {@link #init}. Each is matched against the beginning
     * of the servlet path.
     */
    private final List<Pattern> excludes = new ArrayList<>();
    /**
     * XSS filtering switch; off unless the "enabled" init parameter says otherwise.
     */
    private boolean enabled = false;

    /**
     * Reads the "excludes" and "enabled" init parameters.
     *
     * @throws java.util.regex.PatternSyntaxException if an exclusion entry is not a valid regex,
     *     so a bad configuration fails at startup instead of on the first request
     */
    @Override
    public void init(FilterConfig filterConfig) {
        String tempExcludes = filterConfig.getInitParameter("excludes");
        String tempEnabled = filterConfig.getInitParameter("enabled");
        if (StrUtil.isNotEmpty(tempExcludes)) {
            for (String url : tempExcludes.split(",")) {
                String trimmed = url.trim();
                // An empty entry (e.g. from ",a" or "a,,b") would compile to a pattern matching
                // every path and silently disable the filter, so it is skipped.
                if (!trimmed.isEmpty()) {
                    excludes.add(Pattern.compile(trimmed));
                }
            }
        }
        if (StrUtil.isNotEmpty(tempEnabled)) {
            enabled = Boolean.parseBoolean(tempEnabled.trim());
        }
    }

    /**
     * Passes excluded (or non-HTTP) requests through unchanged; wraps all others so their
     * parameters and JSON bodies are escaped.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        // Only HTTP requests can be wrapped; anything else is forwarded untouched.
        if (!(request instanceof HttpServletRequest)) {
            chain.doFilter(request, response);
            return;
        }
        HttpServletRequest req = (HttpServletRequest) request;
        if (handleExcludeUrl(req)) {
            chain.doFilter(request, response);
            return;
        }
        chain.doFilter(new XssHttpServletRequestWrapper(req), response);
    }

    /**
     * Decides whether the request should bypass XSS filtering.
     *
     * @return {@code true} when filtering is disabled or the servlet path starts with a match of
     *     one of the exclusion patterns
     */
    private boolean handleExcludeUrl(HttpServletRequest request) {
        if (!enabled) {
            return true;
        }
        String url = request.getServletPath();
        for (Pattern pattern : excludes) {
            // lookingAt() anchors the match at the start of the path; unlike prefixing "^",
            // it also keeps alternations such as "a|b" anchored as a whole.
            if (pattern.matcher(url).lookingAt()) {
                return true;
            }
        }
        return false;
    }
}
继承 HttpServletRequestWrapper,对请求参数和 JSON 请求体做 XSS 转义
package com.xss.filter;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import com.xss.util.EscapeUtil;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
/**
 * Request wrapper that HTML-escapes request parameters and the body of JSON requests.
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
    XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Returns the first value of the parameter, HTML-escaped and trimmed, or {@code null} if the
     * parameter is absent. Overridden so that callers using getParameter see the same escaped
     * value as callers using {@link #getParameterValues}.
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        return value == null ? null : EscapeUtil.escape(value).trim();
    }

    /**
     * Returns every value of the parameter, HTML-escaped and trimmed, or {@code null} if the
     * parameter is absent.
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = super.getParameterValues(name);
        if (values == null) {
            return null;
        }
        String[] escapedValues = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            // escape for XSS and strip leading/trailing whitespace
            escapedValues[i] = EscapeUtil.escape(values[i]).trim();
        }
        return escapedValues;
    }

    /**
     * For JSON requests, reads the whole body as UTF-8, escapes HTML-significant characters
     * inside it and returns an in-memory stream over the result. Other requests get the original
     * stream.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        // non-JSON bodies (forms, uploads, ...) are passed through untouched
        if (!isJsonRequest()) {
            return super.getInputStream();
        }
        String json = IoUtil.read(super.getInputStream(), "utf-8");
        // nothing to escape in an empty body
        if (StrUtil.isEmpty(json)) {
            return super.getInputStream();
        }
        json = escapeJson(json).trim();
        final ByteArrayInputStream bis = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8));
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return bis.available() == 0;
            }

            @Override
            public boolean isReady() {
                // the body is fully buffered in memory, so a read never blocks
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // non-blocking reads are not supported for the buffered body
            }

            @Override
            public int read() {
                return bis.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return bis.read(b, off, len);
            }
        };
    }

    /**
     * Escapes HTML-significant characters in a JSON document without breaking its structure.
     * The characters {@code < > & '} can only appear inside JSON string literals, so replacing
     * them with entities keeps the document valid. The double quote is deliberately left alone:
     * it delimits JSON strings, and escaping it would make the body unparseable.
     */
    private static String escapeJson(String json) {
        StringBuilder out = new StringBuilder(json.length() + (json.length() >> 2));
        for (int i = 0; i < json.length(); i++) {
            char c = json.charAt(i);
            switch (c) {
                case '<':
                    out.append("&lt;");
                    break;
                case '>':
                    out.append("&gt;");
                    break;
                case '&':
                    out.append("&amp;");
                    break;
                case '\'':
                    out.append("&#39;");
                    break;
                default:
                    out.append(c);
            }
        }
        return out.toString();
    }

    /**
     * Whether the Content-Type is application/json, with or without parameters such as
     * "; charset=utf-8" (compared case-insensitively). A prefix match is used so that header
     * variations cannot be used to bypass body escaping.
     */
    private boolean isJsonRequest() {
        String header = super.getHeader(HttpHeaders.CONTENT_TYPE);
        String json = MediaType.APPLICATION_JSON_VALUE;
        return header != null && header.regionMatches(true, 0, json, 0, json.length());
    }
}
转义工具类
package com.xss.util;
import cn.hutool.core.util.StrUtil;
/**
 * Escaping helpers: HTML-escapes text and decodes percent-style escape sequences.
 */
public class EscapeUtil {
    /**
     * Replacement text indexed by character code, for every character below 64. Entries not
     * overridden below map a character to itself; the HTML-significant ones map to entities.
     */
    private static final char[][] TEXT = new char[64][];

    static {
        for (int i = 0; i < 64; i++) {
            TEXT[i] = new char[] {(char) i};
        }
        // special HTML characters (all have code points below 64, so the table can hold them)
        TEXT['\''] = "&#39;".toCharArray(); // single quote
        TEXT['"'] = "&quot;".toCharArray(); // double quote
        TEXT['&'] = "&amp;".toCharArray(); // ampersand
        TEXT['<'] = "&lt;".toCharArray(); // less-than sign
        TEXT['>'] = "&gt;".toCharArray(); // greater-than sign
    }

    /**
     * HTML-escapes text by replacing {@code ' " & < >} with entities; all other characters are
     * copied unchanged.
     *
     * @param text the text to escape, may be {@code null}
     * @return the escaped text, or an empty string when {@code text} is {@code null} or empty
     */
    public static String escape(String text) {
        if (text == null || text.isEmpty()) {
            return "";
        }
        int len = text.length();
        // entities are longer than the characters they replace; reserve ~25% extra up front
        StringBuilder buffer = new StringBuilder(len + (len >> 2));
        for (int i = 0; i < len; i++) {
            char c = text.charAt(i);
            if (c < 64) {
                buffer.append(TEXT[c]);
            } else {
                buffer.append(c);
            }
        }
        return buffer.toString();
    }

    /**
     * Decodes {@code %XX} (two hex digits) and {@code %uXXXX} (four hex digits) escape sequences,
     * the format used by JavaScript's {@code escape()} function.
     *
     * <p>This is NOT the inverse of {@link #escape}: HTML entities such as {@code &lt;} are left
     * as they are. A {@code %} that does not start a complete, well-formed sequence is kept
     * literally instead of raising an exception.
     *
     * @param content the text to decode, may be {@code null}
     * @return the decoded text; {@code content} itself when it is {@code null} or empty
     */
    public static String unescape(String content) {
        if (content == null || content.isEmpty()) {
            return content;
        }
        int len = content.length();
        StringBuilder out = new StringBuilder(len);
        int i = 0;
        while (i < len) {
            char c = content.charAt(i);
            if (c == '%') {
                if (i + 1 < len && content.charAt(i + 1) == 'u') {
                    int code = parseHex(content, i + 2, 4);
                    if (code >= 0) {
                        out.append((char) code);
                        i += 6;
                        continue;
                    }
                } else {
                    int code = parseHex(content, i + 1, 2);
                    if (code >= 0) {
                        out.append((char) code);
                        i += 3;
                        continue;
                    }
                }
            }
            // ordinary character, or a '%' that does not start a valid sequence
            out.append(c);
            i++;
        }
        return out.toString();
    }

    /**
     * Parses exactly {@code count} hex digits of {@code s} starting at {@code start}.
     *
     * @return the parsed value, or -1 if the digits run past the end of {@code s} or any of
     *     them is not a hex digit
     */
    private static int parseHex(String s, int start, int count) {
        if (start + count > s.length()) {
            return -1;
        }
        int value = 0;
        for (int k = start; k < start + count; k++) {
            int digit = Character.digit(s.charAt(k), 16);
            if (digit < 0) {
                return -1;
            }
            value = (value << 4) | digit;
        }
        return value;
    }
}
Filter配置类
package com.xss.config;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.DispatcherType;
import cn.hutool.core.util.StrUtil;
import com.xss.filter.XssFilter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Registers {@link XssFilter} with settings read from the {@code xss.*} properties.
 */
@Configuration
public class FilterConfig {
    /** Filtering switch ("true"/"false"); defaults to false so the filter is inert when unset. */
    @Value("${xss.enabled:false}")
    private String enabled;
    /** Comma-separated exclusion regexes, matched against the start of the servlet path. */
    @Value("${xss.excludes:}")
    private String excludes;
    /** Comma-separated servlet URL patterns the filter is mapped to. */
    @Value("${xss.urlPatterns:/*}")
    private String urlPatterns;

    /**
     * Creates the registration for the XSS filter.
     *
     * @return a registration mapping {@link XssFilter} to the configured URL patterns for
     *     REQUEST dispatches
     */
    @Bean
    public FilterRegistrationBean<XssFilter> xssFilterRegistration() {
        FilterRegistrationBean<XssFilter> registration = new FilterRegistrationBean<>();
        registration.setDispatcherTypes(DispatcherType.REQUEST);
        registration.setFilter(new XssFilter());
        // add the filtered paths, tolerating spaces around commas and skipping empty entries
        for (String pattern : StrUtil.split(urlPatterns, ",")) {
            String trimmed = pattern.trim();
            if (!trimmed.isEmpty()) {
                registration.addUrlPatterns(trimmed);
            }
        }
        registration.setName("xssFilter");
        registration.setOrder(Integer.MAX_VALUE);
        // init parameters consumed by XssFilter.init
        Map<String, String> initParameters = new HashMap<>();
        initParameters.put("excludes", excludes);
        initParameters.put("enabled", enabled);
        registration.setInitParameters(initParameters);
        return registration;
    }
}
测试
修改全局配置文件(如 application.yml),开启 XSS 过滤,并配置需要过滤的路径和需要排除的路径。
# 防止XSS攻击
xss:
  # 过滤开关
  enabled: true
  # 排除链接(多个用逗号分隔;每一项按正则表达式匹配请求 servletPath 的开头)
  excludes: /open/*
  # 匹配链接
  urlPatterns: /*
用于测试的controller。
package com.xss.controller;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
public class TestController {
    /**
     * Echoes the "content" parameter. With the sample configuration this path goes through the
     * XSS filter, so the returned value is the escaped form of the input.
     */
    @PostMapping(value = "hello")
    public String hello(String content) {
        return echo(content);
    }

    /**
     * Echoes the "content" parameter. With the sample configuration this path is excluded from
     * the XSS filter, so the input is returned unchanged.
     */
    @PostMapping(value = "open/hello")
    public String open(String content) {
        return echo(content);
    }

    /** Returns the bound request parameter as-is. */
    private static String echo(String value) {
        return value;
    }
}
使用postman进行测试。
请求被过滤接口(POST):http://localhost:8080/hello?content=<script>alert(1);</script>,返回的是转义后的内容 &lt;script&gt;alert(1);&lt;/script&gt;。
请求未被过滤接口(POST):http://localhost:8080/open/hello?content=<script>alert(1);</script>,由于 /open 路径被排除,返回的是原始内容 <script>alert(1);</script>。