Tensorflow: 动态的给变量tf.Variable赋值【tf.assign】

Motivation

错误:
tensorflow不能直接给Variable赋值,比如:

embedding_var = tf.Variable(1)
test_var = 10
embedding_var = test_var

或者:

embedding_var = tf.Variable(1)
init = tf.initialize_all_variables()
sess = tf.InteractiveSession()
sess.run(init)
x.assign(1)

解决方法

正确:
如果只需要给Variable赋值一次,可以通过assign这样进行赋值:

import tensorflow as tf
x = tf.Variable(0)
y = tf.assign(x, 1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(x)
    print sess.run(y)
    print sess.run(x)

但是通常赋一次值的意义不大,因为有时我们想将网络中的一些输出通过saver()保存下来,或者通过tensorboard查看网络的embedding投影,那么就需要将网络中产生的输出以变量的形式储存,这样就可以在saver.save()的时候将这些输出给保存到本地,又因为tensorflow不能在图外面直接对变量进行操作,所以我通过用一个占位符来传输网络的输出结果,然后再session里面取出网络的输出值,feed给该占位符,然后将占位符的值赋给一个临时变量作为保存,如下,亲测有效:

flat_value = np.zeros([200,4*4*32]) 
mid_vari = tf.placeholder(tf.float32, [200,4*4*32],name="mid_vari")
embedding_var = tf.Variable(tf.zeros([200,4*4*32]), name=NAME_TO_VISUALISE_VARIABLE)
mid_vari_2 = tf.assign(embedding_var,mid_vari)

with tf.Session() as sess:
    saver =  tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    for i in range(200):
        flat_value,_=sess.run([flat,mid_vari_2],feed_dict={x:one_x,y:labels,mid_vari:flat_value})

比较周折,不过也是试了很多办法才找到的解决方案T_T。

参考

https://blog.csdn.net/mustar_2017/article/details/79336679

猜你喜欢

转载自blog.csdn.net/LiGuang923/article/details/83833545