sharding-jdbc源码解析之sql执行

SQL执行源码解析

找到这个方法

com.dangdang.ddframe.rdb.sharding.jdbc.core.statement.ShardingPreparedStatement#execute

@Override
    public boolean execute() throws SQLException {
        try {
            Collection<PreparedStatementUnit> preparedStatementUnits = route();
//            创建预编译statement的sql执行器
            return new PreparedStatementExecutor(
                    getShardingConnection().getShardingContext().getExecutorEngine(), getRouteResult().getSqlStatement().getType(), preparedStatementUnits, getParameters()).execute();
        } finally {
//            释放内存
            clearBatch();
        }
    }

组装预编译对象执行单元集合

/**
 * Prepared-statement execution unit: pairs a routed SQL execution unit with the
 * PreparedStatement that was created for it.
 *
 * @author zhangliang
 */
@RequiredArgsConstructor
@Getter
public final class PreparedStatementUnit implements BaseStatementUnit {

//    The routed SQL execution unit (target data source name and rewritten SQL) this statement belongs to.
//    NOTE: @RequiredArgsConstructor takes final fields in declaration order - do not reorder these fields.
    private final SQLExecutionUnit sqlExecutionUnit;

//    The prepared statement to execute for this unit.
    private final PreparedStatement statement;
}
进入到route()方法
private Collection<PreparedStatementUnit> route() throws SQLException {
        Collection<PreparedStatementUnit> result = new LinkedList<>();
//        执行sql路由逻辑并得到路由结果并装载支持静态分片的预编译statement对象
        setRouteResult(routingEngine.route(getParameters()));
//        遍历最小sql执行单元
        for (SQLExecutionUnit each : getRouteResult().getExecutionUnits()) {
//            获取sql类型
            SQLType sqlType = getRouteResult().getSqlStatement().getType();
            Collection<PreparedStatement> preparedStatements;
            if (SQLType.DDL == sqlType) {
//                如果是DDL,创建DDL的prepareStatement对象
                preparedStatements = generatePreparedStatementForDDL(each);
            } else {
//                DDL之外的语句创建prepareStatement对象
                preparedStatements = Collections.singletonList(generatePreparedStatement(each));
            }
//            装载路由的statement对象
            getRoutedStatements().addAll(preparedStatements);
            for (PreparedStatement preparedStatement : preparedStatements) {
                replaySetParameter(preparedStatement);
                result.add(new PreparedStatementUnit(each, preparedStatement));
            }
        }
        return result;
    }
//        遍历最小sql执行单元
        for (SQLExecutionUnit each : getRouteResult().getExecutionUnits()) {
//            获取sql类型
            SQLType sqlType = getRouteResult().getSqlStatement().getType();
//                如果是DDL,创建DDL的prepareStatement对象
                preparedStatements = generatePreparedStatementForDDL(each);
private Collection<PreparedStatement> generatePreparedStatementForDDL(final SQLExecutionUnit sqlExecutionUnit) throws SQLException {
        Collection<PreparedStatement> result = new LinkedList<>();
//        获取可以执行DDL语句的数据库连接对象集合
        Collection<Connection> connections = getShardingConnection().getConnectionForDDL(sqlExecutionUnit.getDataSource());
        for (Connection each : connections) {
//            创建prepareStatement对象
            result.add(each.prepareStatement(sqlExecutionUnit.getSql(), getResultSetType(), getResultSetConcurrency(), getResultSetHoldability()));
        }
        return result;
    }
//        获取可以执行DDL语句的数据库连接对象集合
        Collection<Connection> connections = getShardingConnection().getConnectionForDDL(sqlExecutionUnit.getDataSource());
public Collection<Connection> getConnectionForDDL(final String dataSourceName) throws SQLException {
        final Context metricsContext = MetricsContext.start(Joiner.on("-").join("ShardingConnection-getConnectionForDDL", dataSourceName));
//        从分片规则的数据库分片规则中获取数据源
        DataSource dataSource = shardingContext.getShardingRule().getDataSourceRule().getDataSource(dataSourceName);
        Preconditions.checkState(null != dataSource, "Missing the rule of %s in DataSourceRule", dataSourceName);
        Collection<DataSource> dataSources = new LinkedList<>();
        if (dataSource instanceof MasterSlaveDataSource) {
            dataSources.add(((MasterSlaveDataSource) dataSource).getMasterDataSource());
            dataSources.addAll(((MasterSlaveDataSource) dataSource).getSlaveDataSources());
        } else {
            dataSources.add(dataSource);
        }
        Collection<Connection> result = new LinkedList<>();
        for (DataSource each : dataSources) {
//            根据数据源获取数据库连接
            Connection connection = each.getConnection();
            replayMethodsInvocation(connection);//重新调用调用过的方法动作
            result.add(connection);
        }
        MetricsContext.stop(metricsContext);
        return result;
    }
向上返回到这里
private Collection<PreparedStatementUnit> route() throws SQLException {
        Collection<PreparedStatementUnit> result = new LinkedList<>();
//        执行sql路由逻辑并得到路由结果并装载支持静态分片的预编译statement对象
        setRouteResult(routingEngine.route(getParameters()));
//        遍历最小sql执行单元
        for (SQLExecutionUnit each : getRouteResult().getExecutionUnits()) {
//            获取sql类型
            SQLType sqlType = getRouteResult().getSqlStatement().getType();
            Collection<PreparedStatement> preparedStatements;
            if (SQLType.DDL == sqlType) {
//                如果是DDL,创建DDL的prepareStatement对象
                preparedStatements = generatePreparedStatementForDDL(each);
            } else {
//                DDL之外的语句创建prepareStatement对象
                preparedStatements = Collections.singletonList(generatePreparedStatement(each));
            }
//            装载路由的statement对象
            getRoutedStatements().addAll(preparedStatements);
            for (PreparedStatement preparedStatement : preparedStatements) {
                replaySetParameter(preparedStatement);
                result.add(new PreparedStatementUnit(each, preparedStatement));
            }
        }
        return result;
    }
} else {
//                DDL之外的语句创建prepareStatement对象
                preparedStatements = Collections.singletonList(generatePreparedStatement(each));
            }
private PreparedStatement generatePreparedStatement(final SQLExecutionUnit sqlExecutionUnit) throws SQLException {
        Optional<GeneratedKey> generatedKey = getGeneratedKey();
//        获取数据库连接
        Connection connection = getShardingConnection().getConnection(sqlExecutionUnit.getDataSource(), getRouteResult().getSqlStatement().getType());
//        创建prepareStatement对象
        if (isReturnGeneratedKeys() || isReturnGeneratedKeys() && generatedKey.isPresent()) {
            return connection.prepareStatement(sqlExecutionUnit.getSql(), RETURN_GENERATED_KEYS);
        }
        return connection.prepareStatement(sqlExecutionUnit.getSql(), getResultSetType(), getResultSetConcurrency(), getResultSetHoldability());
    }

获取数据库连接对象

//        获取数据库连接
        Connection connection = getShardingConnection().getConnection(sqlExecutionUnit.getDataSource(), getRouteResult().getSqlStatement().getType());
public Connection getConnection(final String dataSourceName, final SQLType sqlType) throws SQLException {
//        从缓存中获取数据源连接
        Optional<Connection> connection = getCachedConnection(dataSourceName, sqlType);
        if (connection.isPresent()) {
            return connection.get();
        }
        Context metricsContext = MetricsContext.start(Joiner.on("-").join("ShardingConnection-getConnection", dataSourceName));
//        根据数据源名称获取数据源对象
        DataSource dataSource = shardingContext.getShardingRule().getDataSourceRule().getDataSource(dataSourceName);
        Preconditions.checkState(null != dataSource, "Missing the rule of %s in DataSourceRule", dataSourceName);
        String realDataSourceName;
        if (dataSource instanceof MasterSlaveDataSource) {
            dataSource = ((MasterSlaveDataSource) dataSource).getDataSource(sqlType);
            realDataSourceName = MasterSlaveDataSource.getDataSourceName(dataSourceName, sqlType);
        } else {
            realDataSourceName = dataSourceName;
        }
        Connection result = dataSource.getConnection();
        MetricsContext.stop(metricsContext);
        connectionMap.put(realDataSourceName, result);
        replayMethodsInvocation(result);
        return result;
    }

向上返回到这里

private Collection<PreparedStatementUnit> route() throws SQLException {
        Collection<PreparedStatementUnit> result = new LinkedList<>();
//        执行sql路由逻辑并得到路由结果并装载支持静态分片的预编译statement对象
        setRouteResult(routingEngine.route(getParameters()));
//        遍历最小sql执行单元
        for (SQLExecutionUnit each : getRouteResult().getExecutionUnits()) {
//            获取sql类型
            SQLType sqlType = getRouteResult().getSqlStatement().getType();
            Collection<PreparedStatement> preparedStatements;
            if (SQLType.DDL == sqlType) {
//                如果是DDL,创建DDL的prepareStatement对象
                preparedStatements = generatePreparedStatementForDDL(each);
            } else {
//                DDL之外的语句创建prepareStatement对象
                preparedStatements = Collections.singletonList(generatePreparedStatement(each));
            }
            getRoutedStatements().addAll(preparedStatements);
            for (PreparedStatement preparedStatement : preparedStatements) {
                replaySetParameter(preparedStatement);
                result.add(new PreparedStatementUnit(each, preparedStatement));
            }
        }
        return result;
    }
//            装载路由的statement对象
            getRoutedStatements().addAll(preparedStatements);

向上返回到这里

@Override
    public boolean execute() throws SQLException {
        try {
            Collection<PreparedStatementUnit> preparedStatementUnits = route();
//            创建预编译statement的sql执行器
            return new PreparedStatementExecutor(
                    getShardingConnection().getShardingContext().getExecutorEngine(), getRouteResult().getSqlStatement().getType(), preparedStatementUnits, getParameters()).execute();
        } finally {
//            释放内存
            clearBatch();
        }
    }

进入到这个方法

com.dangdang.ddframe.rdb.sharding.executor.type.prepared.PreparedStatementExecutor#execute 执行sql请求

/**
 * 执行SQL请求.
 * 
 * @return true表示执行DQL, false表示执行的DML
 */
public boolean execute() {
    Context context = MetricsContext.start("ShardingPreparedStatement-execute");
    try {
        List<Boolean> result = executorEngine.executePreparedStatement(sqlType, preparedStatementUnits, parameters, new ExecuteCallback<Boolean>() {
            
            @Override
            public Boolean execute(final BaseStatementUnit baseStatementUnit) throws Exception {
                return ((PreparedStatement) baseStatementUnit.getStatement()).execute();
            }
        });
        if (null == result || result.isEmpty() || null == result.get(0)) {
            return false;
        }
        return result.get(0);
    } finally {
        MetricsContext.stop(context);
    }
}

进入执行PreparedStatement对象的方法

com.dangdang.ddframe.rdb.sharding.executor.ExecutorEngine#executePreparedStatement

/**
 * Executes prepared statements with a single parameter list.
 *
 * @param sqlType SQL type of the statements
 * @param preparedStatementUnits statement execution units to run
 * @param parameters the one parameter list bound to the statements
 * @param executeCallback callback invoked for each unit
 * @param <T> callback result type
 * @return per-unit results as returned by the generic execute path
 */
public <T> List<T> executePreparedStatement(
        final SQLType sqlType, final Collection<PreparedStatementUnit> preparedStatementUnits, final List<Object> parameters, final ExecuteCallback<T> executeCallback) {
    // A non-batch execution is handled as a run with exactly one parameter set.
    final List<List<Object>> singleParameterSet = Collections.singletonList(parameters);
    return execute(sqlType, preparedStatementUnits, singleParameterSet, executeCallback);
}

com.dangdang.ddframe.rdb.sharding.executor.ExecutorEngine#execute

进入sql执行引擎的这个方法

private  <T> List<T> execute(
            final SQLType sqlType, final Collection<? extends BaseStatementUnit> baseStatementUnits, final List<List<Object>> parameterSets, final ExecuteCallback<T> executeCallback) {
        if (baseStatementUnits.isEmpty()) {
            return Collections.emptyList();
        }
        Iterator<? extends BaseStatementUnit> iterator = baseStatementUnits.iterator();
//        获得一个sql语句执行单元
        BaseStatementUnit firstInput = iterator.next();
//        异步多线程去执行->
        ListenableFuture<List<T>> restFutures = asyncExecute(sqlType, Lists.newArrayList(iterator), parameterSets, executeCallback);
        T firstOutput;
        List<T> restOutputs;
        try {
//            同步执行->
            firstOutput = syncExecute(sqlType, firstInput, parameterSets, executeCallback);
//            获取执行结果
            restOutputs = restFutures.get();
            //CHECKSTYLE:OFF
        } catch (final Exception ex) {
            //CHECKSTYLE:ON
            ExecutorExceptionHandler.handleException(ex);
            return null;
        }
        List<T> result = Lists.newLinkedList(restOutputs);
        result.add(0, firstOutput);
        return result;
    }
//        异步多线程去执行->
        ListenableFuture<List<T>> restFutures = asyncExecute(sqlType, Lists.newArrayList(iterator), parameterSets, executeCallback);

进入到这个方法

com.dangdang.ddframe.rdb.sharding.executor.ExecutorEngine#asyncExecute

private <T> ListenableFuture<List<T>> asyncExecute(
            final SQLType sqlType, final Collection<BaseStatementUnit> baseStatementUnits, final List<List<Object>> parameterSets, final ExecuteCallback<T> executeCallback) {
        List<ListenableFuture<T>> result = new ArrayList<>(baseStatementUnits.size());
//        是否有异常出现
        final boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
//        执行数据是多线程安全的
        final Map<String, Object> dataMap = ExecutorDataMap.getDataMap();
        for (final BaseStatementUnit each : baseStatementUnits) {
//            线程分发执行
            result.add(executorService.submit(new Callable<T>() {
                
                @Override
                public T call() throws Exception {
                    return executeInternal(sqlType, each, parameterSets, executeCallback, isExceptionThrown, dataMap);
                }
            }));
        }

        return Futures.allAsList(result);
    }

进入这个方法

return executeInternal(sqlType, each, parameterSets, executeCallback, isExceptionThrown, dataMap);
private <T> T executeInternal(final SQLType sqlType, final BaseStatementUnit baseStatementUnit, final List<List<Object>> parameterSets, final ExecuteCallback<T> executeCallback, 
                          final boolean isExceptionThrown, final Map<String, Object> dataMap) throws Exception {
//        同一个数据源是串行执行的
        synchronized (baseStatementUnit.getStatement().getConnection()) {
            T result;
            ExecutorExceptionHandler.setExceptionThrown(isExceptionThrown);
            ExecutorDataMap.setDataMap(dataMap);
            List<AbstractExecutionEvent> events = new LinkedList<>();
            if (parameterSets.isEmpty()) {
//                添加执行事件-》
                events.add(getExecutionEvent(sqlType, baseStatementUnit, Collections.emptyList()));
            }
            for (List<Object> each : parameterSets) {
//                添加执行事件
                events.add(getExecutionEvent(sqlType, baseStatementUnit, each));
            }
            for (AbstractExecutionEvent event : events) {
//                这里是事件总线实现,发布事件
                EventBusInstance.getInstance().post(event);
            }
            try {
//                回调函数获取回调结果
                result = executeCallback.execute(baseStatementUnit);
            } catch (final SQLException ex) {
//                执行失败,更新事件,发布执行失败的事件
                for (AbstractExecutionEvent each : events) {
                    each.setEventExecutionType(EventExecutionType.EXECUTE_FAILURE);
                    each.setException(Optional.of(ex));
                    EventBusInstance.getInstance().post(each);
                    ExecutorExceptionHandler.handleException(ex);
                }
                return null;
            }
            for (AbstractExecutionEvent each : events) {
//                执行成功,更新事件内容,发布执行成功事件
                each.setEventExecutionType(EventExecutionType.EXECUTE_SUCCESS);
                EventBusInstance.getInstance().post(each);
            }
            return result;
        }
    }

向上返回到这个方法

com.dangdang.ddframe.rdb.sharding.executor.ExecutorEngine#execute

//            同步执行->
            firstOutput = syncExecute(sqlType, firstInput, parameterSets, executeCallback);

SQL批量执行源码解析

进入到这个方法

com.dangdang.ddframe.rdb.sharding.jdbc.core.statement.ShardingPreparedStatement#executeBatch

@Override
    public int[] executeBatch() throws SQLException {
        try {
            return new BatchPreparedStatementExecutor(
//                    创建批量statement执行器并执行批量sql
                    getShardingConnection().getShardingContext().getExecutorEngine(), getRouteResult().getSqlStatement().getType(), batchStatementUnits, parameterSets).executeBatch();
        } finally {
//            释放内存
            clearBatch();
        }
    }
/**
 * 执行批量SQL.
 * 
 * @return 执行结果
 */
public int[] executeBatch() {
    Context context = MetricsContext.start("ShardingPreparedStatement-executeBatch");
    try {
        return accumulate(executorEngine.executeBatch(sqlType, batchPreparedStatementUnits, parameterSets, new ExecuteCallback<int[]>() {
            
            @Override
            public int[] execute(final BaseStatementUnit baseStatementUnit) throws Exception {
                return baseStatementUnit.getStatement().executeBatch();
            }
        }));
    } finally {
        MetricsContext.stop(context);
    }
}
/**
 * Executes batched prepared statements.
 *
 * @param sqlType SQL type of the statements
 * @param batchPreparedStatementUnits batch statement execution units to run
 * @param parameterSets all parameter sets added to the batch
 * @param executeCallback callback invoked for each unit
 * @return per-unit results as returned by the generic execute path
 */
public List<int[]> executeBatch(
        final SQLType sqlType, final Collection<BatchPreparedStatementUnit> batchPreparedStatementUnits, final List<List<Object>> parameterSets, final ExecuteCallback<int[]> executeCallback) {
    // Batch execution reuses the generic execute path, passing every parameter set through.
    final List<int[]> perUnitResults = execute(sqlType, batchPreparedStatementUnits, parameterSets, executeCallback);
    return perUnitResults;
}
/**
 * 预编译语句对象的执行上下文.
 * 
 * @author zhangliang
 */
@RequiredArgsConstructor
@Getter
public final class BatchPreparedStatementUnit implements BaseStatementUnit {
    
//    sql最小执行单元
    private final SQLExecutionUnit sqlExecutionUnit;
    
//    预编译statement对象
    private final PreparedStatement statement;

最后调用的是ExecutorEngine#execute方法（与非批量执行共用同一个执行入口）

private  <T> List<T> execute(
            final SQLType sqlType, final Collection<? extends BaseStatementUnit> baseStatementUnits, final List<List<Object>> parameterSets, final ExecuteCallback<T> executeCallback) {
        if (baseStatementUnits.isEmpty()) {
            return Collections.emptyList();
        }
        Iterator<? extends BaseStatementUnit> iterator = baseStatementUnits.iterator();
//        获得一个sql语句执行单元
        BaseStatementUnit firstInput = iterator.next();
//        异步多线程去执行->
        ListenableFuture<List<T>> restFutures = asyncExecute(sqlType, Lists.newArrayList(iterator), parameterSets, executeCallback);
        T firstOutput;
        List<T> restOutputs;
        try {
//            同步执行->
            firstOutput = syncExecute(sqlType, firstInput, parameterSets, executeCallback);
//            获取执行结果
            restOutputs = restFutures.get();
            //CHECKSTYLE:OFF
        } catch (final Exception ex) {
            //CHECKSTYLE:ON
            ExecutorExceptionHandler.handleException(ex);
            return null;
        }
        List<T> result = Lists.newLinkedList(restOutputs);
        result.add(0, firstOutput);
        return result;
    }

以上是sql执行逻辑的源码解析。

说到最后

以上内容，仅供参考。

猜你喜欢

转载自my.oschina.net/u/3775437/blog/1785125