记录 之 tensorflow中几个常用的函数:tf.unstack,tf.concat() 和 tf.stack() 等

1.tf.to_int32();tf.to_float()等 函数,主要是强制类型转换函数;

2.tf.shape(tensor);获取tensor的尺寸

3.tf.round(a);四舍五入函数,张量的值四舍五入为最接近的整数

4.tf.unstack(matrix, axis  =  ‘ ’ );矩阵分解函数
matrix:需要拆解的矩阵
axis:沿某一维度进行拆解,值得注意的是,在使用的时候,axis = 不可缺省,axis 取值范围是[-a,a),a是matrx的维度

例:

import tensorflow as tf

mat = tf.constant([1,2,3],[4,5,6])

o1 = tf.unstack(mat,axis = 0)

o2 = tf.unstack(mat,axis = 1)

sess = tf.Session()

print(sess.run(o1))

>>>[array([1, 2, 3]), array([3, 4, 5])]

>>>[array([1, 3]), array([2, 4]), array([3, 5])]

5.tf.concat() 和 tf.stack() 是我们需要重点区分的两个函数,两者讲道理功能都是进行张量拼接,但是tf.concat函数不会在原有的基础上增加维度,所以在进行通道拼接时常选用tf.concat。而tf.stack()函数在原有的基础上增加一个维度,即由原来的n维变换为n+1维,我们分别来看两个例子:

(1). tf.concat()

import tensorflow as tf

a = tf.constant([[[1,2],[2,3],[3,4]], [[5,6],[7,8],[8,9]]]) #[1,3,2]
b = tf.constant([[[11,12],[12,13],[13,14]], [[15,16],[17,18],[18,19]]]) #[1,3,2]

o1 = tf.concat([a,b], axis = 0)

o2 = tf.concat([a,b], axis = 1)

o3 = tf.concat([a,b], axis = 2)

sess = tf.Session()

print(sess.run(o1))

print(sess.run(o2))

print(sess.run(o3))

output:

>>>[[[ 1  2]
  [ 2  3]
  [ 3  4]]

 [[ 5  6]
  [ 7  8]
  [ 8  9]]

 [[11 12]
  [12 13]
  [13 14]]

 [[15 16]
  [17 18]
  [18 19]]]      #维度:[2,3,2]
>>>[[[ 1  2]
  [ 2  3]
  [ 3  4]
  [11 12]
  [12 13]
  [13 14]]

 [[ 5  6]
  [ 7  8]
  [ 8  9]
  [15 16]
  [17 18]
  [18 19]]]      #维度:[1,6,2]
>>>[[[ 1  2 11 12]
  [ 2  3 12 13]
  [ 3  4 13 14]]

 [[ 5  6 15 16]
  [ 7  8 17 18]
  [ 8  9 18 19]]]  #维度:[1,3,4]

(2).tf.stack()

import tensorflow as tf

a = tf.constant([[1,2],[2,3],[3,4]]) #[3,2]
b = tf.constant([[5,6],[7,8],[8,9]]) #[3,2]

o1 = tf.stack([a,b], axis = 0)

o2 = tf.stack([a,b], axis = 1)

o3 = tf.stack([a,b], axis = 2)

sess = tf.Session()

print(sess.run(o1))

print(sess.run(o2))

print(sess.run(o3))

output:
>>>[[[1 2]
     [2 3]
     [3 4]]

    [[5 6]
     [7 8]
     [8 9]]] #维度:[2,3,2]

>>>[[[1 2]
     [5 6]]

    [[2 3]
     [7 8]]

    [[3 4]
     [8 9]]] #维度:[3,2,2]

>>>[[[1 5]
     [2 6]]

    [[2 7]
     [3 8]]

    [[3 8]
     [4 9]]] #维度:[3,2,2]

这个变换过程手动操作可以理解,以axis = 1 为例:

先将a变换为维度为[3,1,2]的矩阵:即

[[[1,2]],
 [[2,3]],
 [[3,4]]] 

同理b做同样的变换,即

[[[5,6]],
 [[6,7]],
 [[7,8]]]

然后就做类似concat的操作,即成为了:

[[[1 2]
  [5 6]]

 [[2 3]
  [7 8]]

 [[3 4]
  [8 9]]]

到这里我们就知道了tf.caoncat 和 tf.stack的区别及作用,也学会了手动操作。


 

猜你喜欢

转载自blog.csdn.net/qq_41368074/article/details/109998502
今日推荐