1. tf.split
该函数主要用于对tensor进行分割,一般在设置多GPU并行计算时经常会被用到,主要是将一个batch数据集进行平分,分配给各个GPU,最后再汇总各个GPU得到的损失,从而加快模型的训练速度,其主要参数的定义如下:
- value:待分割的 `Tensor` .
- num_or_size_splits: 可以是一个整数,表示分割的后的数量,也可以是一个整数列表,表示分割后每一份的size
- axis:分割的维度,默认的第一维
import tensorflow as tf
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name="split"
)
2. tf.add_n
该函数主要是对输入的tensor列表中每一个tensor进行加总,要求每个tensor的维度必须相同,当开启并行计算时,该函数也经常被用来计算各个GPU得到的损失,其主要参数定义如下:
- inputs:一个tensor列表
import tensorflow as tf
tf.add_n(inputs, name=None)