tensorflow:常用API-'a'

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/jiangpeng59/article/details/77473471

1.加法操作
tf.accumulate_n、tf.add_n、tf.add

import tensorflow as tf
a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 0], [0, 6]])
c = tf.constant([2, 3])

sess = tf.InteractiveSession()

print(tf.accumulate_n([a, b]).eval())
print(tf.add_n([a, b]).eval())
print(tf.add(a, b).eval())
# 输出的结果都是
# [[ 6  2]
# [ 3 10]]

# 但是tf.add支持broadcasting
print(tf.add(a, c).eval())
# [[6 2]
# [8 4]]
# print(tf.add_n([a,c]).eval()) #不支持广播-error

小结:多个tensor对应相加,推荐tf_add_n,若需要支持广播(2个shape不一样的tensor进行操作),请使用tf.add

2.argmax和argmin

a1 = tf.constant([[4, 2, 3], [1, 6, 5]])
print(tf.argmax(a1, axis=0).eval())
# 默认,按列取最大值的下标[0 1 1]
print(tf.argmax(a1, axis=1).eval())
# 按行取最大值的下标[0 1]

tf还提供了arg_max函数,其功能和argmax一样,arg_max是一个待抛弃的函数,推荐使用argmax,argmin和argmax类似

3.assign
assign对tensor的引用进行重新赋值

a2 = tf.Variable(3, dtype=tf.float32)
sess.run(tf.global_variables_initializer())
print(a2.eval())  # 3
tf.assign(a2, 5).eval()  # 必须得eval执行下
print(a2.eval())  # 5
tf.assign_add(a2,2).eval() #加上一个值再从新赋值
print(a2.eval())  # 7

4.as_string
类似tostring函数,不过作用在tensor上

a3 = tf.constant([[1.13, 2.02], [3.5, 4.433]])
print(tf.as_string(a3,precision=2).eval()) #保留2位小数
# [[b'1.13' b'2.02']
# [b'3.50' b'4.43']]

猜你喜欢

转载自blog.csdn.net/jiangpeng59/article/details/77473471