使用 MyBatis 的 Interceptor(拦截器)实现分页。代码参考了别人的项目,但已经找不到出处了(网上类似的案例很多)。如有错误请指正,O(∩_∩)O谢谢。
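大致思路:MyBatis 根据 @Intercepts / @Signature 中的配置信息(要拦截的类型、方法名、参数类型等)判断哪些方法需要拦截。被拦截的方法被调用时,代理(Plugin)会把目标对象、方法和参数封装成 Invocation,然后调用 Interceptor 的 intercept 方法;拦截器在 intercept 中按自己的需求处理参数,再调用 Invocation 的 proceed 方法继续执行原方法。因此 Executor 方法的执行流程大致是:拦截器代理对象 -> 拦截器 -> 目标方法,即:
Executor.query(本例拦截的方法) -> Plugin.invoke -> Interceptor.intercept -> Invocation.proceed -> Method.invoke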
在 MyBatis 的配置文件(如 mybatis-config.xml)中注册拦截器:
<plugins>
    <plugin interceptor="com.java.MybatisPageableInterceptor"/>
</plugins>
@Intercepts({@Signature(type = Executor.class, method = "query", args =
{MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})})
public class MybatisPageableInterceptor implements Interceptor {
static int MAPPED_STATEMENT_INDEX = 0;
static int PARAMETER_INDEX = 1;
static int ROWBOUNDS_INDEX = 2;
int page = 2;
int pagesize = 10;
String orderBY = "order by id ";
public Object intercept(Invocation inv) throws Throwable {
//inv 有三个参数:target调用的对象,method调用的方法,args参数。与@Intercepts对应起来
final Object[] queryArgs = inv.getArgs();
final MappedStatement ms = (MappedStatement) queryArgs[MAPPED_STATEMENT_INDEX];
// 查找方法参数中的 分页请求对象
//sql请求的参数,可能是一个对象
final Object parameter = queryArgs[PARAMETER_INDEX];
//获取动态sql,BoundSql
final BoundSql boundSql = ms.getBoundSql(parameter);
// 删除尾部的 ';'
String sql = boundSql.getSql().trim().replaceAll(";$", "");
// 1. 搞定总记录数(如果需要的话)
int total = this.queryTotal(sql, ms, boundSql);
// 2. 搞定limit 查询
// 2.1 获取分页SQL,并完成参数准备,给sql加上limit
String limitSql = MySQLDialect.getLimitString(sql, (page - 1) * size, pagesize, orderBY);
//获取新的RowBounds
queryArgs[ROWBOUNDS_INDEX] = new RowBounds(RowBounds.NO_ROW_OFFSET, RowBounds.NO_ROW_LIMIT);
//获取新的MappedStatement
queryArgs[MAPPED_STATEMENT_INDEX] = copyFromNewSql(ms, boundSql, limitSql);
// 2.2 继续执行剩余步骤,获取查询结果
Object ret = inv.proceed();
// 3. 组成分页对象
Pager<?> pi = new Pager<>(page, pagesize, total, (List<Object>) ret);
// 4. MyBatis 需要返回一个List对象,这里只是满足MyBatis而作的临时包装
List<Pager<?>> tmp = new ArrayList(1);
tmp.add(pi);
return tmp;
}
/**
* 查询总记录数使用sql的count方法,使用的是常见的jdbc的方式执行sql获取总数
* @param sql
* @param mappedStatement
* @param boundSql
* @return
* @throws SQLException
*/
private int queryTotal(String sql, MappedStatement mappedStatement,
BoundSql boundSql) throws SQLException {
Connection connection = null;
PreparedStatement countStmt = null;
ResultSet rs = null;
try {
connection = mappedStatement.getConfiguration().getEnvironment().getDataSource().getConnection();String countSql = "select count(1) from (" + sql + ") tmp_count";
countStmt = connection.prepareStatement(countSql);
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,
boundSql.getParameterMappings(), boundSql.getParameterObject());
setParameters(countStmt, mappedStatement, countBoundSql, boundSql.getParameterObject());
rs = countStmt.executeQuery();
int totalCount = 0;
if (rs.next()) {
totalCount = rs.getInt(1);
}
return totalCount;
} catch (SQLException e) {
throw e;
} finally {
if (rs != null) {
rs.close();
}
if (countStmt != null) {
countStmt.close();
}
if (connection != null) {
connection.close();
}
BoundSql boundSql, String sql) {
BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), sql, boundSql.getParameterMappings(), boundSql.getParameterObject());
for (ParameterMapping mapping : boundSql.getParameterMappings()) {
String prop = mapping.getProperty();
if (boundSql.hasAdditionalParameter(prop)) {
newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
}
}
return copyFromMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
}
public static class BoundSqlSqlSource implements SqlSource {
BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
//重新生成mappedStatement
private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
}
}
//给sql加上limit进行sql的分页拼装
public static class MySQLDialect {final static String LIMIT_SQL_PATTERN = "%s limit %s, %s";
final static String LIMIT_SQL_PATTERN_FIRST = " %s limit %s";
/**
 * Appends an optional {@code order by} clause and a MySQL {@code limit} clause to {@code sql}.
 *
 * @param sql         the SQL to page (without a trailing ';')
 * @param offset      number of rows to skip; 0 selects the shorter {@code limit <count>} form
 * @param limit       maximum number of rows to return
 * @param orderBy     sort clause such as {@code "id desc"}; a leading {@code order by} keyword is
 *                    accepted and not repeated; {@code null} or blank means no sorting
 * @param tableName   unused; kept for signature compatibility
 * @param tableIdName unused; kept for signature compatibility
 * @return the SQL with the sort and limit clauses appended
 * @throws IllegalArgumentException if {@code offset} or {@code limit} is negative
 */
public static String getLimitString(String sql, int offset, int limit, String orderBy,
        String tableName, String tableIdName) {
    if (offset < 0 || limit < 0) {
        // MySQL rejects negative limit values; fail early with the offending numbers.
        throw new IllegalArgumentException(
                "offset and limit must be >= 0: offset=" + offset + ", limit=" + limit);
    }
    String pagedSql = sql;
    // Strip an "order by" the caller may already have written, so the keyword is emitted only once.
    String sortClause = orderBy == null ? "" : orderBy.trim().replaceFirst("(?i)^order\\s+by\\b\\s*", "");
    if (!sortClause.isEmpty()) {
        pagedSql += " order by " + sortClause;
    }
    if (offset == 0) {
        return String.format(LIMIT_SQL_PATTERN_FIRST, pagedSql, limit);
    }
    return String.format(LIMIT_SQL_PATTERN, pagedSql, offset, limit);
}