tensorflow 2.1 之自动求导函数tf.GradientTape

一、简介

        TensorFlow 2.1的求导计算默认是eager模式,每行代码顺序执行,没有了构建图的过程(也取消了control_dependency的用法)可能会有很大的计算量,需要一个上下文管理器(context manager)来连接需要计算梯度的函数和变量,来减少计算量。

函数: GradientTape(persistent=False,watch_accessed_variables=True)

persistent: 布尔值,用来指定新创建的gradient tape是否是可持续性的。默认是False,意味着只能够调用一次gradient()函数。
watch_accessed_variables: 布尔值,表明这个gradien tape是不是会自动追踪任何能被训练(trainable)的变量。默认是True。要是为False的话,意味着你需要手动去指定你想追踪的那些变量。

二、关联函数

        watch函数:对于不可训练的变量(比如tf.constant)可以使用tape.watch()对其进行“监控”。

        由于GradientTape默认只对tf.Variable(变量)创建的traiable=True属性(默认)的变量进行监控,故需要watch来监控constant函数创建的张量(常量),也可以设置不自动监控可训练变量,完全由自己指定,设置watch_accessed_variables=False即可。 

        若是没有watch函数则结果会是None             

watch(tensor)
作用:确保某个tensor被tape追踪

参数:tensor: 一个Tensor或者一个Tensor列表.

gradient(target,sources,output_gradients=None,unconnected_gradients=tf.UnconnectedGradients.NONE)
作用:根据tape上面的上下文来计算某个或者某些tensor的梯度
参数:target: 被微分的Tensor或者Tensor列表,你可以理解为经过某个函数之后的值
        sources: Tensors 或者Variables列表(可以只有一个值)可以理解为函数的某个变量

三、求导步骤

        1. 一阶导数

        代码:

import tensorflow as tf
x = tf.constant(3.0)

with tf.GradientTape() as g:  # 记录求导的磁带
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x) #求导
print(dy_dx)

        结果:

        

        

        2. 二阶导数

         代码:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  with tf.GradientTape() as gg:
    gg.watch(x)
    y = x * x
  dy_dx = gg.gradient(y, x)      # y’ = 2*x = 2*3 =6
d2y_dx2 = g.gradient(dy_dx, x)  # y’’ = 2

print(dy_dx)
print(d2y_dx2)

结果:

        

        3. 复合求导

        代码:

import tensorflow as tf

x = tf.constant(3.0)
# 若要多次求导需要设置函数的参数persistent=True
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y
dz_dx = g.gradient(z, x)  # z = y^2 = x^4, z’ = 4*x^3 = 4*3^3
dy_dx = g.gradient(y, x)  # y’ = 2*x = 2*3 = 6
del g  # 删除这个上下文tape

print(dy_dx)

         结果:

 注:默认情况下GradientTape的资源在调用gradient函数后就被释放,再次调用就无法计算了,若要多次求导需要设置函数的参数persistent=True

参考:

https://www.cnblogs.com/SupremeBoy/p/12246528.html

猜你喜欢

转载自blog.csdn.net/qq_46006468/article/details/119357429
2.1