一、简介
TensorFlow 2.1的求导计算默认是eager模式,每行代码顺序执行,没有了构建图的过程(也取消了control_dependency的用法)可能会有很大的计算量,需要一个上下文管理器(context manager)来连接需要计算梯度的函数和变量,来减少计算量。
函数: GradientTape(persistent=False,watch_accessed_variables=True)
persistent: 布尔值,用来指定新创建的gradient tape是否是可持续性的。默认是False,意味着只能够调用一次gradient()函数。
watch_accessed_variables: 布尔值,表明这个gradien tape是不是会自动追踪任何能被训练(trainable)的变量。默认是True。要是为False的话,意味着你需要手动去指定你想追踪的那些变量。
二、关联函数
watch函数:对于不可训练的变量(比如tf.constant
)可以使用tape.watch()
对其进行“监控”。
由于GradientTape默认只对tf.Variable(变量)创建的traiable=True属性(默认)的变量进行监控,故需要watch来监控constant函数创建的张量(常量),也可以设置不自动监控可训练变量,完全由自己指定,设置watch_accessed_variables=False即可。
若是没有watch函数则结果会是None
watch(tensor)
作用:确保某个tensor被tape追踪参数:tensor: 一个Tensor或者一个Tensor列表.
gradient(target,sources,output_gradients=None,unconnected_gradients=tf.UnconnectedGradients.NONE)
作用:根据tape上面的上下文来计算某个或者某些tensor的梯度
参数:target: 被微分的Tensor或者Tensor列表,你可以理解为经过某个函数之后的值
sources: Tensors 或者Variables列表(可以只有一个值)可以理解为函数的某个变量
三、求导步骤
1. 一阶导数
代码:
import tensorflow as tf
x = tf.constant(3.0)
with tf.GradientTape() as g: # 记录求导的磁带
g.watch(x)
y = x * x
dy_dx = g.gradient(y, x) #求导
print(dy_dx)
结果:
2. 二阶导数
代码:
import tensorflow as tf
x = tf.constant(3.0)
with tf.GradientTape() as g:
g.watch(x)
with tf.GradientTape() as gg:
gg.watch(x)
y = x * x
dy_dx = gg.gradient(y, x) # y’ = 2*x = 2*3 =6
d2y_dx2 = g.gradient(dy_dx, x) # y’’ = 2
print(dy_dx)
print(d2y_dx2)
结果:
3. 复合求导
代码:
import tensorflow as tf
x = tf.constant(3.0)
# 若要多次求导需要设置函数的参数persistent=True
with tf.GradientTape(persistent=True) as g:
g.watch(x)
y = x * x
z = y * y
dz_dx = g.gradient(z, x) # z = y^2 = x^4, z’ = 4*x^3 = 4*3^3
dy_dx = g.gradient(y, x) # y’ = 2*x = 2*3 = 6
del g # 删除这个上下文tape
print(dy_dx)
结果:
注:默认情况下GradientTape的资源在调用gradient函数后就被释放,再次调用就无法计算了,若要多次求导需要设置函数的参数persistent=True
参考: