MyBatis简单分页插件实现

插件类,实现MyBatis的Interceptor接口
原理分析可参考:MyBatis源码笔记(八) – 插件实现原理

/**
 * MyBatis plugin that turns a {@link RowBounds}-based query into a SQL-level paginated query,
 * optionally preceded by a count query.
 *
 * <p>Every database-specific decision is delegated to a {@link Dialect}: whether to paginate, how to
 * build the count/page SQL, and how to post-process results. The dialect class name is read from
 * the {@code dialect} plugin property in {@link #setProperties(Properties)}.
 *
 * <p>Only the four-argument {@code Executor.query(MappedStatement, Object, RowBounds, ResultHandler)}
 * overload is intercepted.
 */
@Intercepts(
        @Signature(type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
)
public class PageInterceptor implements Interceptor {
    /** The synthetic count statement declares no result mappings; its single column maps to the result type. */
    private static final List<ResultMapping> EMPTYRESULTMAPPING = new
            ArrayList<ResultMapping>(0);
    private Dialect dialect;
    /** Reflective handle to the private {@code BoundSql.additionalParameters} map. */
    private Field additionalParametersField;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Arguments of the intercepted Executor.query call.
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        RowBounds rowBounds = (RowBounds) args[2];
        // The dialect decides whether this query is paginated at all; if not, run it unchanged.
        if (!dialect.skip(ms.getId(), parameterObject, rowBounds)) {
            ResultHandler resultHandler = (ResultHandler) args[3];
            // The executor being wrapped by this plugin.
            Executor executor = (Executor) invocation.getTarget();
            BoundSql boundSql = ms.getBoundSql(parameterObject);
            // Dynamic SQL may generate temporary parameters that live only in
            // BoundSql.additionalParameters; they must be copied onto every BoundSql built below.
            @SuppressWarnings("unchecked")
            Map<String, Object> additionalParameters =
                    (Map<String, Object>) additionalParametersField.get(boundSql);

            if (dialect.beforeCount(ms.getId(), parameterObject, rowBounds)) {
                // A copy of ms whose single result map yields Long.
                MappedStatement countMs = newMappedStatement(ms, Long.class);
                CacheKey countKey = executor.createCacheKey(countMs, parameterObject, rowBounds, boundSql);
                String countSql = dialect.getCountSql(boundSql, parameterObject, rowBounds, countKey);
                BoundSql countBoundSql = new BoundSql(ms.getConfiguration(), countSql,
                        boundSql.getParameterMappings(), parameterObject);
                copyAdditionalParameters(additionalParameters, countBoundSql);
                // Do NOT pass the caller's ResultHandler here. With a non-null handler MyBatis streams
                // rows into that handler and returns an empty list. The count row would then reach
                // the caller and get(0) below would throw.
                List<Object> countResultList = executor.query(countMs, parameterObject,
                        RowBounds.DEFAULT, Executor.NO_RESULT_HANDLER, countKey, countBoundSql);
                long count = ((Number) countResultList.get(0)).longValue();
                dialect.afterCount(count, parameterObject, rowBounds);
                if (count == 0L) {
                    // Nothing matches: skip the page query and hand an empty list to the dialect.
                    return dialect.afterPage(new ArrayList<Object>(), parameterObject, rowBounds);
                }
            }

            if (dialect.beforePage(ms.getId(), parameterObject, rowBounds)) {
                CacheKey pageKey = executor.createCacheKey(ms, parameterObject, rowBounds, boundSql);
                String pageSql = dialect.getPageSql(boundSql, parameterObject, rowBounds, pageKey);
                BoundSql pageBoundSql = new BoundSql(ms.getConfiguration(), pageSql,
                        boundSql.getParameterMappings(), parameterObject);
                copyAdditionalParameters(additionalParameters, pageBoundSql);
                // RowBounds.DEFAULT: paging is now done by the SQL itself, not by MyBatis row skipping.
                List<Object> resultList = executor.query(ms, parameterObject,
                        RowBounds.DEFAULT, resultHandler, pageKey, pageBoundSql);
                return dialect.afterPage(resultList, parameterObject, rowBounds);
            }
        }
        // Not paginated: run the original query.
        return invocation.proceed();
    }

    /**
     * Copies the dynamic-SQL temporary parameters of the original BoundSql onto {@code target}.
     *
     * @param additionalParameters the original BoundSql's additional parameters
     * @param target               the newly built BoundSql that must see the same parameters
     */
    private static void copyAdditionalParameters(Map<String, Object> additionalParameters, BoundSql target) {
        for (Map.Entry<String, Object> entry : additionalParameters.entrySet()) {
            target.setAdditionalParameter(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Creates a copy of {@code ms} with id {@code ms.getId() + "_Count"} and a single result map
     * of type {@code returnType}.
     *
     * @param ms         the statement to copy
     * @param returnType the result type of the new statement's only result map
     * @return the new mapped statement
     */
    private MappedStatement newMappedStatement(MappedStatement ms, Class<?> returnType) {
        MappedStatement.Builder builder =
                new MappedStatement.Builder(ms.getConfiguration(), ms.getId() + "_Count",
                        ms.getSqlSource(), ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
            // Builder.keyProperty takes a comma-delimited string.
            StringBuilder keyProperties = new StringBuilder();
            for (String keyProperty : ms.getKeyProperties()) {
                keyProperties.append(keyProperty).append(",");
            }
            keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
            builder.keyProperty(keyProperties.toString());
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        List<ResultMap> resultMaps = new ArrayList<ResultMap>();
        ResultMap resultMap = new ResultMap.Builder(ms.getConfiguration(), ms.getId(), returnType, EMPTYRESULTMAPPING).build();
        resultMaps.add(resultMap);
        builder.resultMaps(resultMaps);
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Instantiates the dialect named by the {@code dialect} property and hands it all properties.
     * It also resolves the reflective accessor for {@code BoundSql.additionalParameters}.
     *
     * @param properties plugin properties from the MyBatis configuration
     * @throws RuntimeException if the dialect cannot be instantiated or the BoundSql field is missing
     */
    @Override
    public void setProperties(Properties properties) {
        String dialectClass = properties.getProperty("dialect");
        try {
            dialect = (Dialect) Class.forName(dialectClass).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            // Keep the cause so a typo in the class name or a failing constructor is diagnosable.
            throw new RuntimeException("使用PageInterceptor分页插件,必须要设置dialect", e);
        }
        dialect.setProperties(properties);
        try {
            additionalParametersField = BoundSql.class.getDeclaredField("additionalParameters");
            additionalParametersField.setAccessible(true);
        } catch (NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
    }
}

Dialect接口

/**
 * Database-specific strategy consulted by the pagination interceptor.
 *
 * <p>For each intercepted query, {@link #skip} is asked first. If it returns false, the interceptor
 * optionally runs a count query: {@link #beforeCount}, then {@link #getCountSql}, then
 * {@link #afterCount}. It then optionally runs the paginated query: {@link #beforePage}, then
 * {@link #getPageSql}, then {@link #afterPage}.
 */
public interface Dialect {

    /**
     * Decides whether the count and page steps should both be bypassed.
     *
     * @param msId            fully qualified id of the MyBatis statement being executed
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging arguments
     * @return {@code true} to bypass pagination and return the default result, {@code false} to paginate
     */
    boolean skip(String msId, Object parameterObject, RowBounds rowBounds);

    /**
     * Called before the paginated query.
     *
     * @param msId            fully qualified id of the MyBatis statement being executed
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging arguments
     * @return {@code true} to run a count query first; {@code false} to go straight to {@link #beforePage}
     */
    boolean beforeCount(String msId, Object parameterObject, RowBounds rowBounds);

    /**
     * Builds the SQL that counts the rows of the original query.
     *
     * @param boundSql        the original bound SQL
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging arguments
     * @param countKey        cache key of the count query
     * @return the count SQL
     */
    String getCountSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey);

    /**
     * Receives the result of the count query.
     *
     * @param count           the row count
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging arguments
     */
    void afterCount(long count, Object parameterObject, RowBounds rowBounds);

    /**
     * Called before the paginated query.
     *
     * @param msId            fully qualified id of the MyBatis statement being executed
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging arguments
     * @return {@code true} to run the paginated query, {@code false} to return the default result
     */
    boolean beforePage(String msId, Object parameterObject, RowBounds rowBounds);

    /**
     * Builds the paginated SQL for the original query.
     *
     * @param boundSql        the original bound SQL
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging arguments
     * @param pageKey         cache key of the paginated query
     * @return the paginated SQL
     */
    String getPageSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey);

    /**
     * Post-processes the page result.
     *
     * <p>The interceptor returns this method's return value directly.
     *
     * @param pageList        rows of the current page
     * @param parameterObject the statement's parameter object
     * @param rowBounds       the paging arguments
     * @return the value to hand back to the caller
     */
    Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds);

    /**
     * Receives the plugin's configuration properties.
     *
     * @param properties the plugin properties
     */
    void setProperties(Properties properties);
}

Dialect接口的MySQL数据库实现类

/**
 * {@link Dialect} for MySQL.
 *
 * <p>It counts by wrapping the original query in a {@code select count(*)} sub-query, and it
 * paginates by appending {@code limit offset,limit}.
 */
public class MySqlDialect implements Dialect {

    /** Bypasses pagination when the caller passed no paging argument (MyBatis' shared DEFAULT instance). */
    @Override
    public boolean skip(String msId, Object parameterObject, RowBounds rowBounds) {
        return rowBounds == RowBounds.DEFAULT;
    }

    /** Counts only when the bounds can store a total, i.e. they are a {@link PageRowBounds}. */
    @Override
    public boolean beforeCount(String msId, Object parameterObject, RowBounds rowBounds) {
        return rowBounds instanceof PageRowBounds;
    }

    /** Wraps the original SQL in a derived table and counts its rows. */
    @Override
    public String getCountSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey) {
        StringBuilder countSql = new StringBuilder("select count(*) from (");
        countSql.append(boundSql.getSql());
        countSql.append(") temp");
        return countSql.toString();
    }

    /** Stores the total on the {@link PageRowBounds}, which {@link #beforeCount} guaranteed. */
    @Override
    public void afterCount(long count, Object parameterObject, RowBounds rowBounds) {
        PageRowBounds pageRowBounds = (PageRowBounds) rowBounds;
        pageRowBounds.setTotal(count);
    }

    /** Paginates whenever non-default bounds were supplied. */
    @Override
    public boolean beforePage(String msId, Object parameterObject, RowBounds rowBounds) {
        return rowBounds != RowBounds.DEFAULT;
    }

    /** Appends MySQL's {@code limit offset,limit} clause; also mixes a marker into the page cache key. */
    @Override
    public String getPageSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey) {
        pageKey.update("RowBounds");
        StringBuilder pageSql = new StringBuilder();
        pageSql.append(boundSql.getSql())
                .append(" limit ")
                .append(rowBounds.getOffset())
                .append(",")
                .append(rowBounds.getLimit());
        return pageSql.toString();
    }

    /** Returns the page rows unchanged. */
    @Override
    public Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds) {
        return pageList;
    }

    /** This dialect takes no configuration. */
    @Override
    public void setProperties(Properties properties) {
    }
}

可记录查询总数的分页参数类,继承MyBatis的RowBounds

/**
 * {@link RowBounds} that can also carry the total number of rows matched by the un-paginated query.
 *
 * <p>The total starts at 0 and is set through {@link #setTotal(long)}.
 */
public class PageRowBounds extends RowBounds {

    /** Total row count; 0 until set. */
    private long total;

    /** Creates bounds via RowBounds' no-argument constructor. */
    public PageRowBounds() {
        // implicit super()
    }

    /**
     * Creates bounds for the given window.
     *
     * @param offset number of rows to skip
     * @param limit  maximum number of rows to return
     */
    public PageRowBounds(int offset, int limit) {
        super(offset, limit);
    }

    /** @return the stored total row count */
    public long getTotal() {
        return this.total;
    }

    /** @param total the total row count to store */
    public void setTotal(long total) {
        this.total = total;
    }
}

在mybatis-config.xml配置文件中添加插件

	<plugins>
        <!-- Registers PageInterceptor; the "dialect" property names the Dialect implementation class to instantiate. -->
        <plugin interceptor="tk.mybatis.simple.plugin.PageInterceptor">
            <property name="dialect" value="tk.mybatis.simple.plugin.MySqlDialect"/>
        </plugin>
    </plugins>

猜你喜欢

转载自blog.csdn.net/seasonLai/article/details/82964166