图解TensorFlow中Tensor的shape概念与tf op: tf.reshape

田海立@CSDN 2020-10-18

图解NCHW与NHWC数据格式》中从逻辑表达和物理存储角度用图的方式讲述了NHWC与NCHW两种数据格式,数据shape是可以改变的,本文介绍TensorFlow里Tensor的Shape概念,并用图示和程序阐述了reshape运算。

一、TensorFlow中Tensor的Shape

TensorFlow中的数据都是由Tensor来表示,Shape相关有下列一些概念:

  • Rank:维数
  • Dimension:表达每一维长度
  • Size:所有的Dimension数值相乘,也就是Tensor里数据元素的尺寸了

rank为0/1/2的典型Tensor如下图所示:

Tensor rank为3时,数据表达为:

二、Tensor的逻辑表达与物理存储

如《图解NCHW与NHWC数据格式》中所述,数据可以从逻辑上和物理排布上去理解。而本文第一节中你可以仍从逻辑上去理解,还未牵涉到物理存储数据排布。

三维以下的比较容易理解,各个ML框架之间也没大的区别,对于三维(及以上)Tensor的排布就很不同了,这里着重介绍3-D。

我们已经知道TensorFlow的Tensor缺省是NHWC的,对于上面的shape(3, 2, 5)的Tensor【n为1】,在TensorFlow中应该是这样的:

如果数据值按顺序排布如下,

      [[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]]

那么对应上面三维立方体的摆放应该如下:

三、tf.reshape()运算

reshape原型如下:

tf.reshape(
    tensor, shape, name=None
)

3.1 tf.reshape()的不改变性

  1. tf.reshape()运算不改变数据的物理排布,也就是说一个Tensor reshape到别的shape只是逻辑上shape的改变,存储的数据不会改变;
  2. tf.reshape()也就不会改变Tensor的size。指定的新的shape的size如果与原Tensor的size不一致就会报错。比如上面Shape(3, 2, 5)的Tensor就没法reshape成7x?。

有了上面两个原则,tf.reshape()运算就很容易理解了,物理存储不变,就看rank以及各个dimension怎么取了。

3.2 tf.reshape() 图示

比如,上面Tensor有30个数:从0~29顺序存储。可以存储为(3, 2, 5)【上面介绍过的3-D】,也可以存储为2D的(3, 10)或(6, 5),也可以存储为1-D的,全部展开。

3.3 程序实现如下:

TF2.0以后的版本上,直接可以执行,而不用还要在session下执行。当然前提是已经

import tensorflow as tf

1. 30个数的数据

>>> t = tf.range(30)
>>> t
<tf.Tensor: shape=(30,), dtype=int32, numpy=
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32)>
>>> 

2. shape(3, 2, 5)

>>> t = tf.reshape(t, [3,2,5])
>>> t
<tf.Tensor: shape=(3, 2, 5), dtype=int32, numpy=
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]], dtype=int32)>
>>> 

3. shape(3, 10)

>>> t = tf.reshape(t, [3, 10])
>>> t
<tf.Tensor: shape=(3, 10), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]], dtype=int32)>
>>> 

4. shape(6, 5)

>>> t = tf.reshape(t, [6, 5])
>>> t
<tf.Tensor: shape=(6, 5), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29]], dtype=int32)>
>>> 

四、小结

本文介绍了TensorFlow里Tensor的Shape概念,并用图示和实际程序解释了reshape的变化。

猜你喜欢

转载自blog.csdn.net/thl789/article/details/109139190
今日推荐