nump.expand_dims() 与 tensorflow.expand_dim() 函数的异同

nump.expand_dims(array, axis),tensorflow.expand_dim(tensor, axis)

这两个expand_dims函数都是在原始数据的基础上,添加第axis维.

不同点在于处理的数据类型不同,前者是处理array类型的数据,后者是处理tensor类型的数据。

nump.expand_dims(array, axis) 用法

原始数据 

import numpy as np
x = np.array([[1, 2, 3], [4, 5, 6]])
print (x)
print (x.shape)
print ("x[0][1]: ",x[0][1])
[[1 2 3]
 [4 5 6]]
(2, 3)
x[0][1]:  2

扩展维度 

#在第0维添加1
y = np.expand_dims(x,axis=0)
print (y)
print ("y.shape: ",y.shape)
print ("y[0][1]: ",y[0][1])
print ("y[0][0][1]: ",y[0][0][1])
[[[1 2 3]
  [4 5 6]]]
y.shape:  (1, 2, 3)
y[0][1]:  [4 5 6]
y[0][0][1]:  2
#在第1维添加1
y = np.expand_dims(x,axis=1)
print (y)
print ("y.shape: ",y.shape)
print ("y[1][0]: ",y[1][0])
print ("y[0][0][1]: ",y[0][0][1])
[[[1 2 3]]

 [[4 5 6]]]
y.shape:  (2, 1, 3)
y[1][0]:  [4 5 6]
y[0][0][1]:  2
#在第3维添加1
y = np.expand_dims(x,axis=2)
print (y)
print ("y.shape: ",y.shape)
print ("y[1][0]: ",y[1][0])
print ("y[0][1][0]: ",y[0][1][0])
[[[1]
  [2]
  [3]]

 [[4]
  [5]
  [6]]]
y.shape:  (2, 3, 1)
y[1][0]:  [4]
y[0][1][0]:  2

tensorflow.expand_dim(tensor, axis) 用法 

原始数据

import tensorflow as tf
x = tf.Variable([[1, 2, 3], [4, 5, 6]])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    value = sess.run(x)
    print (value)
    print (x.shape)
    print ("x[0][1]: ",value[0][1])
[[1 2 3]
 [4 5 6]]
(2, 3)
x[0][1]:  2

扩展维度

#在第0维添加1
y = tf.expand_dims(x,axis=0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    value = sess.run(y)
    print (value)
    print (y.shape)
    print ("y[0][1]: ",value[0][1])
[[[1 2 3]
  [4 5 6]]]
(1, 2, 3)
y[0][1]:  [4 5 6]

猜你喜欢

转载自blog.csdn.net/Muzi_Water/article/details/81388393