插件类,实现MyBatis的Interceptor接口
原理分析可参考:MyBatis源码笔记(八) – 插件实现原理
/**
 * MyBatis plugin that intercepts
 * {@code Executor.query(MappedStatement, Object, RowBounds, ResultHandler)} and replaces the
 * RowBounds-based query with dialect-generated count / page SQL.
 *
 * <p>The SQL rewriting is delegated to a {@link Dialect} whose class name is read from the
 * {@code dialect} plugin property in {@link #setProperties(Properties)}. For every intercepted
 * query the dialect decides whether to skip paging, whether to run a count query first, and
 * whether to run a rewritten page query.
 */
@Intercepts(
    @Signature(type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
)
public class PageInterceptor implements Interceptor {
    /** Empty result-mapping list used when building the count statement's ResultMap. */
    private static final List<ResultMapping> EMPTYRESULTMAPPING = new ArrayList<ResultMapping>(0);
    /** Dialect instance created from the "dialect" plugin property. */
    private Dialect dialect;
    /** Reflective handle on BoundSql's private "additionalParameters" field. */
    private Field additionalParametersField;

    /**
     * Runs the optional count query and the paged query for the intercepted call.
     *
     * @param invocation the intercepted call; its args are
     *                   (MappedStatement, parameter object, RowBounds, ResultHandler)
     * @return the value returned by {@link Dialect#afterPage}, or the result of the unmodified
     *         query when the dialect skips paging or declines the page query
     * @throws Throwable anything thrown by the executor or by the original invocation
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        RowBounds rowBounds = (RowBounds) args[2];
        // The dialect says no paging is needed: run the original query untouched
        if (dialect.skip(ms.getId(), parameterObject, rowBounds)) {
            return invocation.proceed();
        }
        ResultHandler resultHandler = (ResultHandler) args[3];
        Executor executor = (Executor) invocation.getTarget();
        BoundSql boundSql = ms.getBoundSql(parameterObject);
        // Dynamic SQL can produce temporary parameters that BoundSql only keeps in this private map;
        // they must be copied onto every rewritten BoundSql
        Map<String, Object> additionalParameters = getAdditionalParameters(boundSql);

        if (dialect.beforeCount(ms.getId(), parameterObject, rowBounds)) {
            long count = executeCount(executor, ms, parameterObject, rowBounds, boundSql, additionalParameters);
            dialect.afterCount(count, parameterObject, rowBounds);
            if (count == 0L) {
                // Nothing to fetch: skip the page query and hand back an empty result
                return dialect.afterPage(new ArrayList<Object>(), parameterObject, rowBounds);
            }
        }
        if (dialect.beforePage(ms.getId(), parameterObject, rowBounds)) {
            CacheKey pageKey = executor.createCacheKey(ms, parameterObject, rowBounds, boundSql);
            String pageSql = dialect.getPageSql(boundSql, parameterObject, rowBounds, pageKey);
            BoundSql pageBoundSql = copyBoundSql(ms, pageSql, boundSql, parameterObject, additionalParameters);
            // The paging is now part of the SQL itself, so the query runs with RowBounds.DEFAULT
            List<Object> resultList = executor.query(ms, parameterObject,
                RowBounds.DEFAULT, resultHandler, pageKey, pageBoundSql);
            return dialect.afterPage(resultList, parameterObject, rowBounds);
        }
        // The dialect declined the page query: fall back to the original query and RowBounds
        return invocation.proceed();
    }

    /**
     * Builds and executes the count query derived from {@code boundSql}.
     *
     * @return the first (and only expected) value returned by the count query
     * @throws Exception anything thrown by the executor
     */
    private long executeCount(Executor executor, MappedStatement ms, Object parameterObject,
                              RowBounds rowBounds, BoundSql boundSql,
                              Map<String, Object> additionalParameters) throws Exception {
        // Clone the statement so that its single result column maps to Long
        MappedStatement countMs = newMappedStatement(ms, Long.class);
        CacheKey countKey = executor.createCacheKey(countMs, parameterObject, rowBounds, boundSql);
        String countSql = dialect.getCountSql(boundSql, parameterObject, rowBounds, countKey);
        BoundSql countBoundSql = copyBoundSql(ms, countSql, boundSql, parameterObject, additionalParameters);
        // Never pass the caller's ResultHandler here. With a non-null handler, MyBatis streams rows
        // to the handler and returns an empty list. get(0) would then throw, and the caller's
        // handler would receive a Long instead of its own result type.
        List<Object> countResultList = executor.query(countMs, parameterObject,
            RowBounds.DEFAULT, Executor.NO_RESULT_HANDLER, countKey, countBoundSql);
        return (Long) countResultList.get(0);
    }

    /**
     * Creates a BoundSql for {@code sql}. It reuses the original parameter mappings and carries
     * over every additional parameter, so the placeholders in the rewritten SQL still resolve.
     */
    private BoundSql copyBoundSql(MappedStatement ms, String sql, BoundSql boundSql,
                                  Object parameterObject, Map<String, Object> additionalParameters) {
        BoundSql copy = new BoundSql(ms.getConfiguration(), sql, boundSql.getParameterMappings(), parameterObject);
        for (Map.Entry<String, Object> entry : additionalParameters.entrySet()) {
            copy.setAdditionalParameter(entry.getKey(), entry.getValue());
        }
        return copy;
    }

    /** Reads BoundSql's private "additionalParameters" map via the cached reflective field. */
    @SuppressWarnings("unchecked")
    private Map<String, Object> getAdditionalParameters(BoundSql boundSql) throws IllegalAccessException {
        return (Map<String, Object>) additionalParametersField.get(boundSql);
    }

    /**
     * Creates a copy of {@code ms} whose id has "_Count" appended and whose single ResultMap
     * targets {@code returnType}.
     *
     * @param ms         the statement to copy
     * @param returnType the result type of the new statement
     * @return the new MappedStatement
     */
    private MappedStatement newMappedStatement(MappedStatement ms, Class<?> returnType) {
        MappedStatement.Builder builder =
            new MappedStatement.Builder(ms.getConfiguration(), ms.getId() + "_Count",
                ms.getSqlSource(), ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        String[] keyProperties = ms.getKeyProperties();
        if (keyProperties != null && keyProperties.length != 0) {
            // Builder.keyProperty expects a comma-separated list. The original code appended the
            // builder to itself and never passed the result on, silently dropping the key properties.
            StringBuilder joined = new StringBuilder();
            for (int i = 0; i < keyProperties.length; i++) {
                if (i > 0) {
                    joined.append(',');
                }
                joined.append(keyProperties[i]);
            }
            builder.keyProperty(joined.toString());
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        List<ResultMap> resultMaps = new ArrayList<ResultMap>();
        ResultMap resultMap = new ResultMap.Builder(ms.getConfiguration(), ms.getId(), returnType, EMPTYRESULTMAPPING).build();
        resultMaps.add(resultMap);
        builder.resultMaps(resultMaps);
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /** Wraps {@code target} in a proxy that routes the intercepted signature to this plugin. */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Instantiates the configured {@link Dialect}, forwards the plugin properties to it, and
     * prepares reflective access to BoundSql's additional parameters.
     *
     * @param properties plugin properties; "dialect" must name a Dialect implementation class
     * @throws RuntimeException if the dialect cannot be created, or if BoundSql has no
     *                          "additionalParameters" field
     */
    @Override
    public void setProperties(Properties properties) {
        String dialectClass = properties.getProperty("dialect");
        try {
            dialect = (Dialect) Class.forName(dialectClass).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            // Keep the cause. A missing property, a misspelled class name and a failing constructor
            // all land here and are otherwise indistinguishable.
            throw new RuntimeException("使用PageInterceptor分页插件,必须要设置dialect", e);
        }
        dialect.setProperties(properties);
        try {
            additionalParametersField = BoundSql.class.getDeclaredField("additionalParameters");
            additionalParametersField.setAccessible(true);
        } catch (NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
    }
}
Dialect接口
/**
 * Database-specific strategy used by PageInterceptor to decide whether and how to page a query.
 *
 * <p>Call order for one intercepted query: {@code skip} → (optionally) {@code beforeCount} /
 * {@code getCountSql} / {@code afterCount} → {@code beforePage} / {@code getPageSql} →
 * {@code afterPage}.
 */
public interface Dialect {

    /**
     * Decides whether both the count query and the page query should be bypassed.
     *
     * @param msId            fully qualified id of the MyBatis statement being executed
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging bounds passed by the caller
     * @return {@code true} to skip paging and return the default result; {@code false} to page
     */
    boolean skip(String msId, Object parameterObject, RowBounds rowBounds);

    /**
     * Called before paging. Decides whether a count query runs first.
     *
     * @param msId            fully qualified id of the MyBatis statement being executed
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging bounds passed by the caller
     * @return {@code true} to execute the count query; {@code false} to go straight to the
     *         {@link #beforePage} check
     */
    boolean beforeCount(String msId, Object parameterObject, RowBounds rowBounds);

    /**
     * Produces the SQL of the count query.
     *
     * @param boundSql        the original statement's bound SQL
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging bounds passed by the caller
     * @param countKey        cache key created for the count query
     * @return the count SQL
     */
    String getCountSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey);

    /**
     * Receives the total produced by the count query.
     *
     * @param count           the row count
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging bounds passed by the caller
     */
    void afterCount(long count, Object parameterObject, RowBounds rowBounds);

    /**
     * Called before the page query. Decides whether it runs.
     *
     * @param msId            fully qualified id of the MyBatis statement being executed
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging bounds passed by the caller
     * @return {@code true} to execute the page query; {@code false} to return the default
     *         query result instead
     */
    boolean beforePage(String msId, Object parameterObject, RowBounds rowBounds);

    /**
     * Produces the SQL of the page query.
     *
     * @param boundSql        the original statement's bound SQL
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging bounds passed by the caller
     * @param pageKey         cache key created for the page query
     * @return the paged SQL
     */
    String getPageSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey);

    /**
     * Post-processes the page result. PageInterceptor returns this method's value as-is.
     *
     * @param pageList        rows returned by the page query (empty when the count was 0)
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging bounds passed by the caller
     * @return the value handed back to the MyBatis caller
     */
    Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds);

    /**
     * Receives the plugin's configuration properties.
     *
     * @param properties the properties declared on the plugin element
     */
    void setProperties(Properties properties);
}
Dialect接口的MySQL数据库实现类
/**
 * {@link Dialect} for MySQL. It wraps the original SQL for counting and appends a
 * {@code limit offset,limit} clause for paging.
 */
public class MySqlDialect implements Dialect {

    /** Skips paging when the caller passed the shared {@code RowBounds.DEFAULT} instance. */
    @Override
    public boolean skip(String msId, Object parameterObject, RowBounds rowBounds) {
        return rowBounds == RowBounds.DEFAULT;
    }

    /** Counts only when the bounds are a PageRowBounds, since that is where the total is stored. */
    @Override
    public boolean beforeCount(String msId, Object parameterObject, RowBounds rowBounds) {
        return rowBounds instanceof PageRowBounds;
    }

    /** Wraps the original SQL in a derived table and counts its rows. */
    @Override
    public String getCountSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey) {
        String innerSql = boundSql.getSql();
        return new StringBuilder("select count(*) from (")
            .append(innerSql)
            .append(") temp")
            .toString();
    }

    /** Stores the total on the PageRowBounds (guaranteed by {@link #beforeCount}). */
    @Override
    public void afterCount(long count, Object parameterObject, RowBounds rowBounds) {
        PageRowBounds pageRowBounds = (PageRowBounds) rowBounds;
        pageRowBounds.setTotal(count);
    }

    /** Runs the page query for any bounds other than {@code RowBounds.DEFAULT}. */
    @Override
    public boolean beforePage(String msId, Object parameterObject, RowBounds rowBounds) {
        return rowBounds != RowBounds.DEFAULT;
    }

    /** Marks the cache key as paged and appends MySQL's {@code limit offset,count} clause. */
    @Override
    public String getPageSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey) {
        pageKey.update("RowBounds");
        int offset = rowBounds.getOffset();
        int limit = rowBounds.getLimit();
        return new StringBuilder(boundSql.getSql())
            .append(" limit ")
            .append(offset)
            .append(',')
            .append(limit)
            .toString();
    }

    /** Returns the page rows unchanged. */
    @Override
    public Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds) {
        return pageList;
    }

    /** This dialect takes no configuration. */
    @Override
    public void setProperties(Properties properties) {
        // intentionally empty
    }
}
分页查询带总数的实体类,继承MyBatis的RowBounds
/**
 * RowBounds that also carries the total row count of the query. The total is populated via
 * {@link #setTotal(long)} (MySqlDialect does this in afterCount).
 */
public class PageRowBounds extends RowBounds {

    /** Total number of rows reported by the count query; 0 until set. */
    private long total;

    /** Delegates to RowBounds' no-argument constructor. */
    public PageRowBounds() {
    }

    /**
     * @param offset number of rows to skip
     * @param limit  maximum number of rows to return
     */
    public PageRowBounds(int offset, int limit) {
        super(offset, limit);
    }

    /** @return the total row count recorded by {@link #setTotal(long)} */
    public long getTotal() {
        return this.total;
    }

    /** @param total the total row count to record */
    public void setTotal(long total) {
        this.total = total;
    }
}
在mybatis-config.xml配置文件中添加插件
<plugins>
    <!-- Pagination plugin; the "dialect" property names the Dialect implementation to load -->
    <plugin interceptor="tk.mybatis.simple.plugin.PageInterceptor">
        <property value="tk.mybatis.simple.plugin.MySqlDialect" name="dialect"/>
    </plugin>
</plugins>