tf.reshape(), tf.shape(x) and x.get_shape()

tf.reshape() usage: tf.reshape(tensor, shape, name=None)

>>> import tensorflow as tf
>>> import numpy as np
>>> sess = tf.Session()  # TensorFlow 1.x; sess.run() evaluates a tensor

## Create an array a

>>> a = np.arange(24)
>>> a
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23])

① Basic usage: pass the target shape as a list; the product of its entries must equal the number of elements (24 here)

>>> tf.reshape(a,[12,2])
<tf.Tensor 'Reshape:0' shape=(12, 2) dtype=int32>

>>> sess.run(tf.reshape(a,[4,6]))
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23]])
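reshape does not move any data: it reads the 24 elements in their original order and cuts them into consecutive rows. The sketch below shows that regrouping for the 2-D case in plain Python, assuming row-major (C) order, which is the order tf.reshape uses. to_rows is a hypothetical helper for illustration, not TensorFlow code:

```python
def to_rows(flat, n_cols):
    """Cut a flat list into consecutive rows of n_cols items (row-major)."""
    if n_cols <= 0 or len(flat) % n_cols != 0:
        raise ValueError(f"{len(flat)} elements cannot fill rows of {n_cols}")
    # row r is the slice flat[r * n_cols : (r + 1) * n_cols]
    return [flat[i:i + n_cols] for i in range(0, len(flat), n_cols)]


rows = to_rows(list(range(24)), 6)
for row in rows:
    print(row)
# [0, 1, 2, 3, 4, 5]
# [6, 7, 8, 9, 10, 11]
# [12, 13, 14, 15, 16, 17]
# [18, 19, 20, 21, 22, 23]
```

The rows match the [4, 6] result above one for one.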

② Using -1 in the shape: one dimension may be given as -1, and reshape infers its size from the total number of elements

## Fix the number of rows at 4; the number of columns is inferred (24 / 4 = 6)

>>> sess.run(tf.reshape(a,[4,-1]))
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23]])

## Fix the number of columns at 4; the number of rows is inferred (24 / 4 = 6)
>>> sess.run(tf.reshape(a,[-1,4]))
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])
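The rule behind -1 is plain arithmetic: at most one entry may be -1, and its value is the element count divided by the product of the other entries, which must divide evenly. A minimal pure-Python sketch of that rule, assuming every explicitly given dimension is positive; infer_shape is a hypothetical helper, not part of the TensorFlow API:

```python
from math import prod


def infer_shape(total, shape):
    """Return `shape` with any -1 entry replaced by the size that keeps
    the element count equal to `total` (hypothetical helper, not a TF API)."""
    if shape.count(-1) > 1:
        raise ValueError("only one dimension may be -1")
    # this sketch only handles positive explicit sizes
    if any(d != -1 and d <= 0 for d in shape):
        raise ValueError("explicit dimensions must be positive")
    known = prod(d for d in shape if d != -1)  # product of the fixed sizes
    if -1 not in shape:
        if known != total:
            raise ValueError(f"cannot reshape {total} elements into {shape}")
        return list(shape)
    if total % known != 0:
        raise ValueError(f"{total} elements cannot be split evenly by {known}")
    return [total // known if d == -1 else d for d in shape]


print(infer_shape(24, [4, -1]))     # [4, 6]
print(infer_shape(24, [-1, 4]))     # [6, 4]
print(infer_shape(24, [2, -1, 4]))  # [2, 3, 4]
```

By the same arithmetic, infer_shape(24, [-1, 5]) raises ValueError, because 24 elements cannot be cut into rows of 5.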

③ Meaning of a three-element shape: [2, 3, 4] gives 2 blocks, each with 3 rows of 4 columns

>>> sess.run(tf.reshape(a,[2,3,4]))
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
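Because the last dimension varies fastest, element [i][j][k] of the [2, 3, 4] result comes from position i*12 + j*4 + k of a (each block holds 3*4 = 12 elements, each row 4). A small pure-Python sketch of that row-major index arithmetic; flat_index is a hypothetical name used only for illustration:

```python
def flat_index(index, shape):
    """Position of multi-index `index` in the flat, row-major (C-order)
    layout of a tensor with the given `shape` (hypothetical helper)."""
    if len(index) != len(shape):
        raise ValueError("index and shape must have the same length")
    pos = 0
    for i, dim in zip(index, shape):
        if not 0 <= i < dim:
            raise IndexError(f"index {i} out of range for dimension of size {dim}")
        # Horner-style accumulation: pos = i0*(d1*d2) + i1*d2 + i2
        pos = pos * dim + i
    return pos


shape = (2, 3, 4)
print(flat_index((0, 2, 1), shape))  # 9
print(flat_index((1, 0, 0), shape))  # 12
print(flat_index((1, 2, 3), shape))  # 23
```

Since a holds 0 to 23, each printed position equals the value at that index in the [2, 3, 4] output, e.g. [1][0][0] is 12.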

tf.shape(a) and a.get_shape() usage

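① Both give the shape of a tensor, i.e. the size of each dimension

② tf.shape(a) accepts a tensor, a NumPy array or a Python list (it converts its input to a tensor first) and returns a 1-D int32 tensor, whose value is only available after sess.run(). a.get_shape() is a method of tensor objects only; it returns a TensorShape object (printed like a tuple) holding the shape known when the graph is built, so no sess.run() is needed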
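NumPy happens to draw the same line, which may make rule ② easier to remember: the function np.shape() accepts anything array-like, while the .shape attribute exists only on ndarray objects. This is a NumPy analogy, not TensorFlow code, and it assumes only that NumPy is installed:

```python
import numpy as np

a = np.arange(24)
b = a.reshape(4, 6)
c = [1, 2, 3]

# np.shape() is a function: it falls back to converting its argument
# with np.asarray(), so arrays and plain lists both work
print(np.shape(a))  # (24,)
print(np.shape(b))  # (4, 6)
print(np.shape(c))  # (3,)

# .shape is an attribute that only ndarray objects have
print(b.shape)      # (4, 6)
try:
    c.shape
except AttributeError as err:
    print(err)      # 'list' object has no attribute 'shape'
```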

## Create an array a
>>> a = np.arange(24)
>>> a
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23])

## Create a tensor b from the array a
>>> b = tf.reshape(a,[4,6])
>>> b
<tf.Tensor 'Reshape_11:0' shape=(4, 6) dtype=int32>
>>> sess.run(b)
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23]])

## Create a list c
>>> c = [1,2,3]
>>> c
[1, 2, 3]

Now apply tf.shape(x) to a, b and c:

>>> sess.run(tf.shape(a))
array([24])
>>> sess.run(tf.shape(b))
array([4, 6])
>>> sess.run(tf.shape(c))
array([3])

Next, call x.get_shape() on a, b and c:

>>> a.get_shape()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'

>>> b.get_shape()
TensorShape([Dimension(4), Dimension(6)])

>>> c.get_shape()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'list' object has no attribute 'get_shape'

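So x.get_shape() only works when x is a tensor; a NumPy array or a list raises an AttributeError. The TensorShape it returns can be indexed to get the number of rows and columns separately; in TensorFlow 1.x each entry is a Dimension object, and .value gives the plain integer: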

>>> b.get_shape()
TensorShape([Dimension(4), Dimension(6)])
>>> print(b.get_shape())
(4, 6)
>>> print(b.get_shape()[0])
4
>>> print(b.get_shape()[1])
6
>>> b.get_shape()[0].value
4
>>> b.get_shape()[1].value
6

Reference: https://blog.csdn.net/fireflychh/article/details/73611021
