theano学习之Shared变量

Shared 变量,意思是这些变量可以在运算过程中,不停地进行交换和更新值。 在定义 weights 和 bias 的情况下,会需要用到这样的变量。

import numpy as np
import theano
import theano.tensor as T

#---------------------------shared是用来存放变量的,会不断更新数值---------------------#

state = theano.shared(np.array(0,dtype=np.float64),'state')#用np.array给state赋初值,名字为state
inc = T.scalar('inc',dtype=state.dtype)#定义一个容器,名字为inc,值的类型为state类型,不弄用np.float64
accumulator = theano.function([inc],state,updates=[(state,state+inc)])#定义一个函数,传过来的数为inc的值,结果为state,更新方法为state加inc

#--------输出不能使用print(state),而要用state.get_value()来获取state中个值---------#

print(state.get_value())#不传值的话state里面的值为赋初值里的值
accumulator(1)#把1传到累加器里面去
print(state.get_value())#state的值更新了,变为1
accumulator(10)#把10传过去
print(state.get_value())#state的值更新了,变为11

#---------------可以用set_value来改变state里的值----------------------------------------#
state.set_value(-1)
accumulator(3)
print(state.get_value())#输出值为2

#-------------------------------------临时使用-------------------------------------------#
#有时只是想暂时使用 Shared 变量,并不需要把它更新: 这时我们可以定义一个 a 来临时代替 state,注意定义 a 的时候也要统一dtype
tmp_func = state*2 + inc
a = T.scalar(dtype=state.dtype)
skip_shared = theano.function([inc,a],tmp_func,givens=[(state,a)])#忽略掉 Shared 变量自己的运算,输入值是 [inc,a],相当于把 a 代入 state,输出是 tmp_func,givens 就是想把什么替换成什么。 这样的话,在调用 skip_shared 函数后,state 并没有被改变。
print(skip_shared(2,3))#借用了一下share变量state
print(state.get_value())#原始值还是2



结果:

来源

猜你喜欢

转载自blog.csdn.net/weixin_40849273/article/details/84579935