《Gluon 动手学深度学习 三》MXNet autograd 自动求导

• MXNet 提供 autograd 包来⾃动化求导过程。

• MXNet 的 autograd 包可以对正常的命令式程序进⾏求导。

from mxnet import autograd,nd

#创建变量,并复制
x = nd.arange(4).reshape((4,1))

#先使用attach_grad()为变量梯度申请内存
x.attach_grad()

#定义有关变量x的函数。默认条件下,为了减少计算和内存开销,MXNet不会记录用于求梯度的计算图。我们需要调用record函数来要求MXNet记录与求梯度有关的计算。
#也就是说不能使用之前定义的函数直接求导(分割线以后内容为错误示例)
with autograd.record():
    y = 2 * nd.dot(x.T, x)

#由于x的形状为(4, 1),y是一个标量。接下来我们可以通过调用backward函数自动求梯度。需要注意的是,如果y不是一个标量,MXNet将先对y中元素求和得到新的变量,再求该变量有关x的梯度。
y.backward()

#输出结果
x.grad, x.grad == 4*x # 1为真,0为假。

==========分割线===========

#也就是说不能使用之前定义的函数直接求导(以下内容为错误示例)
from mxnet import autograd, nd
n=nd.arange(4).reshape((4,1))
print(n)
m=(nd.dot(n.T,n))
print(m)
n.attach_grad()

with autograd.record():
    #m=(nd.dot(n.T,n))
    m
m.backward()
n.grad,n.grad ==2*x
[[ 0.]
 [ 1.]
 [ 2.]
 [ 3.]]
<NDArray 4x1 @cpu(0)>

[[ 14.]]
<NDArray 1x1 @cpu(0)>
---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-29-e592fee224e5> in <module>()
      9     #m=(nd.dot(n.T,n))
     10     m
---> 11 m.backward()
     12 n.grad,n.grad ==2*x
     13 

C:\Anaconda2\envs\gluon\lib\site-packages\mxnet\ndarray\ndarray.py in backward(self, out_grad, retain_graph, train_mode)
   2094             ctypes.c_int(train_mode),
   2095             ctypes.c_void_p(0),
-> 2096             ctypes.c_void_p(0)))
   2097 
   2098     def tostype(self, stype):

C:\Anaconda2\envs\gluon\lib\site-packages\mxnet\base.py in check_call(ret)
    147     """
    148     if ret != 0:
--> 149         raise MXNetError(py_str(_LIB.MXGetLastError()))
    150 
    151 

MXNetError: [18:29:55] C:\projects\mxnet-distro-win\mxnet-build\src\imperative\imperative.cc:373: Check failed: !AGInfo::IsNone(*i) Cannot differentiate node because it is not in a computational graph. You need to set is_recording to true or use autograd.record() to save computational graphs for backward. If you want to differentiate the same graph twice, you need to pass retain_graph=True to backward.


猜你喜欢

转载自blog.csdn.net/qq_42189368/article/details/80719967