tensorflow相关tensor计算函数

1. tf.split

该函数主要用于对tensor进行分割,一般在设置多GPU并行计算时经常会被用到,主要是将一个batch数据集进行平分,分配给各个GPU,最后再汇总各个GPU得到的损失,从而加快模型的训练速度,其主要参数的定义如下:

  • value:待分割的 `Tensor` .
  • num_or_size_splits: 可以是一个整数,表示分割的后的数量,也可以是一个整数列表,表示分割后每一份的size
  • axis:分割的维度,默认的第一维 
import tensorflow as tf

tf.split(
    value, 
    num_or_size_splits, 
    axis=0, 
    num=None, 
    name="split"
)

2. tf.add_n

该函数主要是对输入的tensor列表中每一个tensor进行加总,要求每个tensor的维度必须相同,当开启并行计算时,该函数也经常被用来计算各个GPU得到的损失,其主要参数定义如下:

  • inputs:一个tensor列表
import tensorflow as tf

tf.add_n(inputs, name=None)

猜你喜欢

转载自blog.csdn.net/linchuhai/article/details/84892295