Tensorflow 维度变换

1、维度变换介绍
在神经网络运算过程中,维度变换是最核心的张量操作,通过维度变换可以将数据任意地切换形式,满足不同的运算需求。算法的每个模块对于数据张量的格式有不同的逻辑要求,当现有的数据格式不满足算法要求时,需要通过维度变换将数据调整为正确的格式,这就是维度变换的功能。
基本的维度变换操作函数包含改变视图reshape、插入新维度expand_dims、删除维度squeeze、交换维度transpose、复制数据tile等函数。
2、张量与标量
例如:大小shape为[4,32,32,3]的张量,可以理解为有4张图片,每张图片32行32列,每个位置有RGB 3个通道数据,张量的存储体现是在内存中保存的一段连续的内存区域,对于同样的存储,可以有不同的理解方式,例如上述张量,可以在不改变储存,张量可以理解为4张样本,每个样本的特征长度3072的向量。同一个储存,从不同的角度分析数据,可以产生不同的视图,视图是非常灵活的,但也需要符合常理。

import tensorflow as tf
#产生向量
x = tf.range(96)
#改变x的视图,获得四维张量,存储没有改变
x = tf.reshape(x,[2,4,4,3])
print(x)

3、视图变换
为了表达方便,把张量shape列表中相对靠左侧的维度称为大维度,shape列表中相对靠右侧的维度称为小维度,例如[2,4,4,3]的张量中,图片数量维度与通道数量相比,图片数量称为大维度,通道数称为小维度。视图变换只需要满足新视图的元素总量与存储区域大小相等,视图的元素数量等于bhw*c。

张量的初始化视图为[b,h,w,c],写入内存布局,现在改变张量:
(1)、[b,hw,c]:b张图片,hw个像素点,c个通道;
(2)、[b,h,wc]:b张图片,h行,每行有wc个特征;
(3)、[b,c,h,w]:b张图片,c个通道,h行,w,列,改变初始的顺序。

改变视图是神经网络中非常常见的操作,可以通过串联多个reshape操作来实现复杂逻辑,但是在通过reshape改变视图时,必须始终记住张量的存储顺序,新视图的维度顺序不能与存储顺序相悖,否则需要通过交换维度操作将存储顺序同步过来。

#维度数和形状列表
x.ndim,x.shape
#(4, TensorShape([2, 4, 4, 3]))
#参数-1表示当前轴上长度需要根据张量总元素不变的法则自动推导
tf.reshape(x,[2,-1])
'''
<tf.Tensor: shape=(2, 48), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
        64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
        80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95]])>
'''
tf.reshape(x,[2,4,12])
'''
<tf.Tensor: shape=(2, 4, 12), dtype=int32, numpy=
array([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
        [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]],

       [[48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
        [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83],
        [84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95]]])>
'''
tf.reshape(x,[2,16,3])
'''
<tf.Tensor: shape=(2, 16, 3), dtype=int32, numpy=
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23],
        [24, 25, 26],
        [27, 28, 29],
        [30, 31, 32],
        [33, 34, 35],
        [36, 37, 38],
        [39, 40, 41],
        [42, 43, 44],
        [45, 46, 47]],

       [[48, 49, 50],
        [51, 52, 53],
        [54, 55, 56],
        [57, 58, 59],
        [60, 61, 62],
        [63, 64, 65],
        [66, 67, 68],
        [69, 70, 71],
        [72, 73, 74],
        [75, 76, 77],
        [78, 79, 80],
        [81, 82, 83],
        [84, 85, 86],
        [87, 88, 89],
        [90, 91, 92],
        [93, 94, 95]]])>
'''

4、增删维度
一张28×28大小的灰度图片的数据保存为shape为[28,28]的张量,在末尾给张量增加一新维度,定义为通道数维度,此时张量的shape变为[28,28,1]:

x = tf.random.uniform((28,28),maxval=10,dtype=tf.int32)
print('初始维度:',x.shape)
x = tf.expand_dims(x,axis=2)
print('变换后维度:',x.shape)
'''
初始维度: (28, 28)
变换后维度: (28, 28, 1)
'''

可以在最前面插入一个新的维度:

x = tf.expand_dims(x,axis=0)
x.shappe
#TensorShape([1, 28, 28, 1])
x = tf.expand_dims(x,axis=-1)
x.shape
#TensorShape([1, 28, 28, 1, 1])
x = tf.expand_dims(x,axis=-4)
x.shape
#TensorShape([1, 28, 1, 28, 1, 1])

tf.expand_dims的axis为正时,表示在当前维度之前插入一个新维度;为负时,表示当前维度之后插入一个新的维度。
删除维度只能删除长度为1的维度,也不会改变张量的存储。

x = tf.squeeze(x,axis=0)
x.shape
#TensorShape([28, 1, 28, 1, 1])
x = tf.squeeze(x,axis=1)
x.shape
#TensorShape([28,28, 1, 1])
#默认删除所有长度为1的维度
x =tf.squeeze(x)
x.shape
#TensorShape([28, 28])

4、交换维度
交换维度操作,改变了张量的存储顺序,同时也改变了张量的视图:

#[b,h,w,c]到[b,c,h,w]维度交换运算
x = tf.random.uniform((4,28,28,3),maxval=10,dtype=tf.int32)
print('初始维度:',x.shape)
x = tf.transpose(x,perm=(0,3,1,2))
print('变换后维度:',x.shape)
'''
初始维度: (4, 28, 28, 3)
变换后维度: (4, 3, 28, 28)
'''
#[b,h,w,c]交换为[b,w,h,c]
x = tf.random.uniform((4,28,28,3),maxval=10,dtype=tf.int32)
print('初始维度:',x.shape)
x = tf.transpose(x,perm=(0,2,1,3))
print('变换后维度:',x.shape)
'''
初始维度: (4, 28, 28, 3)
变换后维度: (4, 28, 28, 3)
'''

5、复制数据
当通过增加维度操作插入新维度后,可能希望在新的维度上面复制若干份数据:

x = tf.constant([1,2])
print(x.shape)
print(x)
x = tf.expand_dims(x,axis=0)
print(x.shape)
print(x)
x = tf.tile(x,multiples=[2,1])
print(x.shape)
print(x)
'''
(2,)
tf.Tensor([1 2], shape=(2,), dtype=int32)
(1, 2)
tf.Tensor([[1 2]], shape=(1, 2), dtype=int32)
(2, 2)
tf.Tensor(
[[1 2]
 [1 2]], shape=(2, 2), dtype=int32)
'''
x = tf.range(4)
print(x.shape)
print(x)
x = tf.reshape(x,(2,2))
print(x.shape)
print(x)
#在列维度复制1份数据
x = tf.tile(x,multiples=[1,2])
print(x.shape)
print(x)
#在行维度复制1份数据
x = tf.tile(x,multiples=[2,1])
print(x.shape)
print(x)
'''
(4,)
tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
(2, 2)
tf.Tensor(
[[0 1]
 [2 3]], shape=(2, 2), dtype=int32)
 (2, 4)
tf.Tensor(
[[0 1 0 1]
 [2 3 2 3]], shape=(2, 4), dtype=int32)
(4, 4)
tf.Tensor(
[[0 1 0 1]
 [2 3 2 3]
 [0 1 0 1]
 [2 3 2 3]], shape=(4, 4), dtype=int32)
'''

Broadcasting机制
Broadcasting机制的核心思想是普适性,即同一份数据能普遍适合于其他位置。在验证普适性之前,需要先将张量shape靠右对齐,然后进行普适性判断:对于长度为1的维度,默认这个数据普遍适合于当前维度的其他位置;对于不存在的维度,则在增加新维度后默认当前数据也是普适于新维度的,从而可以扩展为更多维度数、任意长度的张量形状。
例如:shape为[w,1]的张量,需要扩展为shape:[b,h,w,c],首先将2个shape靠右对齐,对于通道维度c,张量的长度为1,则默认此数据同样适合当前维度的其他位置,将数据在逻辑上复制c-1份,长度变为c;对于不存在的b和h维度,则自动插入新维度,新维度长度为1,同时默认当前的数据普适于新维度的其他位置,然后将数据b和h维度的长度自动扩展为b和h。

x = tf.random.normal([32,1])
tf.broadcast_to(x,[3,32,32,3])
'''
<tf.Tensor: shape=(3, 32, 32, 3), dtype=float32, numpy=
array([[[[-0.03420183, -0.03420183, -0.03420183],
         [-0.11337528, -0.11337528, -0.11337528],
         [ 1.3235099 ,  1.3235099 ,  1.3235099 ],
         ...,
         ...
'''
#在c维度上,张量已经有2个特征数据,当前维度上的这2个特征无法普适到其他位置,故不满足普适性原则,无法应用Broadcasting机制,将会触发错误
x = tf.random.normal([32,2])
tf.broadcast_to(x,[3,32,32,3])
'''
InvalidArgumentError: Incompatible shapes: [32,2] vs. [3,32,32,3] [Op:BroadcastTo]
'''
#自动Broadcasting机制
x = tf.random.normal([2,32,32,1])
y = tf.random.normal([32,32])
x+y,x-y,x*y,x/y
'''
(<tf.Tensor: shape=(2, 32, 32, 32), dtype=float32, numpy=
 array([[[[-2.34972119e-01, -1.32176340e+00, -8.74061882e-01, ...,
            9.48178649e-01,  7.73117542e-01, -1.03607702e+00],
          [-2.06189919e+00,  5.86489856e-01, -1.94167900e+00, ...,
            1.85062528e-01, -6.44589007e-01, -2.36011267e+00],
          [ 8.50200295e-01,  1.61495817e+00,  1.37277842e+00, ...,
            8.79180253e-01,  1.06644309e+00,  1.89374065e+00],
          ...,
'''

猜你喜欢

转载自blog.csdn.net/weixin_56260304/article/details/128205336
今日推荐