Table of contents
Purpose
Create a filter in Spring Boot that intercepts all POST requests, extracts the parameters from the request body, and checks whether their content is legal. This approach applies only to POST requests: POST and GET requests carry their parameters in different places (the request body versus the query string), so they must be handled differently. To intercept GET requests and validate their parameters, refer to the following example:
Implementation steps
1. Create a Filter that intercepts all requests;
2. Read and convert the body content of POST requests;
3. Validate the body data:
3.1. when the body is a single JSON object;
3.2. when the body is a single JSON array;
3.3. when the body is a JSON object that contains a JSON array;
3.4. when the body is a JSON array that contains JSON objects;
Complete code
Filter class (SqlFilter)
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.boc.ljh.utils.Result;
import com.boc.ljh.utils.status.AppErrorCode;
import org.springframework.context.annotation.Configuration;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
/**
 * @Author: ljh
 * @ClassName SqlFilter
 * @Description Filters request content to block SQL-injection-like parameter values
 * @date 2023/8/8 16:15
 * @Version 1.0
 *
 * <p>POST requests whose media type is {@code application/json} (any parameters such as
 * {@code charset} are ignored) have every value in their body checked, recursively through
 * nested objects and arrays; object keys are not checked. Other POST requests have every value
 * of the request parameter map checked. Requests with other HTTP methods pass through unchecked.
 *
 * <p>A rejected request is answered with a JSON {@code Result} whose status field is 500 and is
 * not forwarded down the filter chain. The HTTP status code itself is left unchanged.
 *
 * <p>NOTE(review): the class carries both {@code @WebFilter} and {@code @Configuration}; depending
 * on whether {@code @ServletComponentScan} is also active, it may be registered twice — confirm.
 */
@WebFilter(urlPatterns = "/*", filterName = "sqlFilter")
@Configuration
public class SqlFilter implements Filter {

    /**
     * Blacklisted SQL keywords and special characters (special characters are escaped with \\).
     * Compiled once because it is evaluated for every checked value of every request.
     *
     * <p>NOTE(review): the pattern is applied with {@code Matcher#matches()}, so a value is only
     * rejected when the WHOLE lower-cased value equals one alternative (e.g. "select" is rejected,
     * "1 or 1=1" is not). This is not a reliable SQL-injection defence — parameterized queries are.
     * The "'\\*" alternative matches the literal two-character text '* ; possibly "'|\\*" was
     * intended — confirm before changing.
     */
    private static final Pattern BAD_SQL_PATTERN = Pattern.compile(
            "select|update|and|or|delete|insert|truncate|char|into|substr|ascii|declare|exec|count|master|into|drop|execute|table|"
                    + "char|declare|sitename|xp_cmdshell|like|from|grant|use|group_concat|column_name|"
                    + "information_schema.columns|table_schema|union|where|order|by|"
                    + "'\\*|\\;|\\-|\\--|\\+|\\,|\\//|\\/|\\%|\\#");

    @Override
    public void init(FilterConfig filterConfig) {
    }

    /**
     * Validates the data of POST requests and either rejects the request or forwards it.
     *
     * @throws IOException if the body cannot be read or the rejection cannot be written
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        ServletRequest forwarded = servletRequest;
        if ("POST".equals(request.getMethod())) {
            if (isJsonContentType(request.getContentType())) {
                // The body stream can only be consumed once: buffer it in a wrapper and forward the
                // wrapper so controllers can still read the body after it has been inspected here.
                BodyReaderRequestWrapper wrapper = new BodyReaderRequestWrapper(request);
                forwarded = wrapper;
                if (!isJsonBodyLegal(wrapper.getBody())) {
                    reject(response);
                    return;
                }
            } else if (!areParametersLegal(request.getParameterMap())) {
                // e.g. application/x-www-form-urlencoded
                reject(response);
                return;
            }
        }
        filterChain.doFilter(forwarded, servletResponse);
    }

    /** Returns true if the Content-Type header's media type is application/json, ignoring parameters. */
    private static boolean isJsonContentType(String contentType) {
        if (contentType == null) {
            return false;
        }
        int semicolon = contentType.indexOf(';');
        String mediaType = semicolon < 0 ? contentType : contentType.substring(0, semicolon);
        return "application/json".equalsIgnoreCase(mediaType.trim());
    }

    /**
     * Checks a JSON request body.
     *
     * @param body the raw body text; only bodies that are a JSON object or array (after leading
     *             whitespace) are inspected, anything else is treated as legal
     * @return {@code false} if any value in the JSON document is blacklisted
     */
    private boolean isJsonBodyLegal(String body) {
        if (body == null) {
            return true;
        }
        String trimmed = body.trim();
        if (trimmed.startsWith("{") || trimmed.startsWith("[")) {
            // Malformed JSON makes fastjson throw, which propagates as a server error (as before).
            return isJsonValueLegal(JSON.parse(trimmed));
        }
        return true;
    }

    /**
     * Recursively checks a parsed JSON value: object values and array elements are visited,
     * JSON null is skipped, and every other value is checked via its string form.
     */
    private boolean isJsonValueLegal(Object value) {
        if (value == null) {
            return true;
        }
        if (value instanceof Map) {
            for (Object child : ((Map<?, ?>) value).values()) {
                if (!isJsonValueLegal(child)) {
                    return false;
                }
            }
            return true;
        }
        if (value instanceof List) {
            for (Object child : (List<?>) value) {
                if (!isJsonValueLegal(child)) {
                    return false;
                }
            }
            return true;
        }
        return !verifySql(String.valueOf(value));
    }

    /** Returns {@code false} if any value of any request parameter is blacklisted. */
    private boolean areParametersLegal(Map<String, String[]> parameterMap) {
        for (String[] values : parameterMap.values()) {
            if (values == null) {
                continue;
            }
            for (String value : values) {
                if (verifySql(value)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Writes the "illegal request data" JSON result to the response. IOExceptions propagate so a
     * failed write can never fall through to forwarding the rejected request.
     */
    private void reject(HttpServletResponse response) throws IOException {
        response.setContentType("application/json;charset=utf-8");
        Result result = new Result();
        result.setStatus(500);
        result.setMessage(AppErrorCode.REQUEST_DATA_FULL.message);
        response.getWriter().print(JSON.toJSONString(result));
    }

    @Override
    public void destroy() {
    }

    /**
     * @Author: ljh
     * @Description: Checks a parameter value for blacklisted SQL keywords / special characters.
     * @DateTime: 14:26 2023/8/9
     * @param parameter the value to check; {@code null} is treated as legal
     * @return {@code true} if the value, lower-cased with {@link Locale#ROOT} (so e.g. "INSERT"
     *         does not become a dotless-i spelling under a Turkish default locale), matches
     *         {@link #BAD_SQL_PATTERN} in full — i.e. the value is illegal
     */
    public boolean verifySql(String parameter) {
        if (parameter == null) {
            return false;
        }
        return BAD_SQL_PATTERN.matcher(parameter.toLowerCase(Locale.ROOT)).matches();
    }
}
Request body wrapper utility class (BodyReaderRequestWrapper)
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
/**
 * @Author: ljh
 * @ClassName BodyReaderRequestWrapper
 * @Description Buffers the request body so it can be read more than once
 * @date 2023/8/8 16:14
 * @Version 1.0
 *
 * <p>The body is read once in the constructor. Its raw bytes are replayed unchanged by every
 * {@link #getInputStream()} / {@link #getReader()} call, so a filter can inspect the body while
 * downstream code still sees exactly what the client sent (including a matching Content-Length).
 */
public class BodyReaderRequestWrapper extends HttpServletRequestWrapper {
    /** Raw body bytes exactly as received; replayed to downstream readers. */
    private final byte[] bodyBytes;
    /** Charset from the request's character encoding, or UTF-8 if absent/unknown. */
    private final java.nio.charset.Charset charset;
    /** Body decoded with {@link #charset}, for inspection. */
    private final String body;

    /** @return the request body decoded with the request's character encoding (UTF-8 fallback) */
    public String getBody() {
        return body;
    }

    /**
     * Reads and caches the whole request body (executed when the object is created).
     *
     * @param request the request whose body is consumed; its input stream is closed afterwards
     * @throws IOException if reading the body fails
     */
    public BodyReaderRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.charset = resolveCharset(request.getCharacterEncoding());
        this.bodyBytes = readAll(request.getInputStream());
        this.body = new String(bodyBytes, charset);
    }

    /** Maps a character-encoding name to a Charset, falling back to UTF-8 when null or unsupported. */
    private static java.nio.charset.Charset resolveCharset(String encoding) {
        if (encoding != null) {
            try {
                return java.nio.charset.Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // IllegalCharsetNameException / UnsupportedCharsetException: fall back to UTF-8.
            }
        }
        return java.nio.charset.StandardCharsets.UTF_8;
    }

    /** Reads the stream to the end and closes it; a null stream yields an empty body. */
    private static byte[] readAll(InputStream in) throws IOException {
        if (in == null) {
            return new byte[0];
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        try (InputStream source = in) {
            int count;
            while ((count = source.read(buffer)) != -1) {
                out.write(buffer, 0, count);
            }
        }
        return out.toByteArray();
    }

    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), charset));
    }

    /** Returns a fresh stream over the cached body bytes on every call. */
    @Override
    public ServletInputStream getInputStream() {
        final ByteArrayInputStream source = new ByteArrayInputStream(bodyBytes);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // Data is in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking reads are not supported by this in-memory replay stream (no-op, as before).
            }

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public int available() {
                return source.available();
            }
        };
    }
}