tensorflow函数之tf.cast,tf.argmax,tf.argmin,tf.equal

1.tf.cast

官方链接:https://www.tensorflow.org/api_docs/python/tf/cast

tf.cast(
    x,
    dtype,
    name=None
)

将一个张量转换为一个新类型。

Casts a tensor to a new type.

举例:

import tensorflow as tf  
a = tf.Variable([1,0,0,1,0])  
b = tf.cast(a,dtype = bool)  
sess = tf.Session()  
sess.run(tf.initialize_all_variables())  
print(sess.run(b))
输出:
[ True False False  True False]

2.tf.argmax

官方链接:https://www.tensorflow.org/api_docs/python/tf/argmax

tf.argmax(
    input,
    axis=None,
    name=None,
    dimension=None,
    output_type=tf.int64
)

返回最大值所在的坐标。

Returns the index with the largest value across axes of a tensor.

举例:

import tensorflow as tf  
import numpy as np  

A = [[1,3,4,5,6]]  
B = [[1,3,4], [2,4,1]]  

with tf.Session() as sess:  
    print(sess.run(tf.argmax(A, 1)))  
    print(sess.run(tf.argmax(B, 1)))  

输出:

[4]
[2 1]

3.tf.argmin

官方链接:https://www.tensorflow.org/api_docs/python/tf/argmin

tf.argmin(
    input,
    axis=None,
    name=None,
    dimension=None,
    output_type=tf.int64
)

返回最小值所在的坐标。

Returns the index with the smallest value across axes of a tensor.

举例:

import tensorflow as tf  
import numpy as np  

A = [[1,3,4,5,6]]  
B = [[1,3,4], [2,4,1]]  

with tf.Session() as sess:  
    print(sess.run(tf.argmin(A, 1)))  
    print(sess.run(tf.argmin(B, 1)))  

输出:

[0]
[0 2]
4.tf.equal()

官方链接:https://www.tensorflow.org/api_docs/python/tf/equal

tf.equal(
    x,
    y,
    name=None
)
tf.equal(A, B)是对比这两个矩阵或者向量的相等的元素,如果是相等的那就返回True,反正返回False,返回的值的矩阵维度和A是一样的

举例:

import tensorflow as tf  
import numpy as np  

A = [[1,3,4,5,6]]  
B = [[1,3,4,3,2]]  

with tf.Session() as sess:  
    print(sess.run(tf.equal(A, B)))  

输出:

[[ True  True  True False False]]

猜你喜欢

转载自blog.csdn.net/weixin_41278720/article/details/80269757