【PyTorch】1.2 PyTorch中的autograd

学习目标:

  • 掌握自动求导中的Tensor概念和操作
  • 掌握自动求导中的梯度Gradients概念和操作
  • 在整个Pytorch框架中,所有的神经网络本质上都是一个autograd package(自动求导工具包)
  • autograd package提供了一个对Tensors上所有的操作进行自动微分的功能

关于torch.Tensor

  • torch.Tensor是整个package中的核心类,如果将属性.requires_grad设置为True,它将追踪在这个类上定义的类上定义的所有操作,当代码要进行反向传播的时候,直接调用.backward()就可以自动计算所有的梯度。在这个Tensor上的所有梯度将被累加进属性.grad中。
  • 如果想终止一个Tensor在计算图中的追踪回溯,只需要执行.detach()就可以将该Tensor从计算图中撤下,在未来的慧诉计算中也不会在计算该Tensor
  • 除了detach()如果想终止对计算图的回溯,也就是不再进行方向传播求导数的过程,也可以采用代码块的方式with torch.no_grad():,这种方式非常适用于对模型进行预测的时候,因为预测阶段不再需要对梯度进行计算

torch.Function

  • Function类是和Tensor类同等重要的一个核心类,它和Tensor共同构建了一个完整的类,每一个Tensor拥有一个.grad_fn属性,代表引用量能够具体的Function创建了该Tensor
  • 如果某个张量Tensor是用户自定义的,则其对应的grad_fn is None

关于Tensor的操作

import torch
x1=torch.ones(3,3)
print(x1)

x=torch.ones(2,2,requires_grad=True)
print(x)

#在具有requires_grad=True的Tensor上进行加法操作
y=x+2
print(y)
print(x.grad_fn)
print(y.grad_fn)
a=torch.randn(2,2)
a=((a*3)/(a-1))
print(a)
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
print(a)
b=(a*a).sum()
print(b)
print(b.grad_fn)
#输出
tensor([[ 1.3117, -2.9088],
        [ 1.4760,  1.3805]])
False
True
tensor([[ 1.3117, -2.9088],
        [ 1.4760,  1.3805]], requires_grad=True)
tensor(14.2663, grad_fn=<SumBackward0>)
<SumBackward0 object at 0x00000200375D3CC8>

关于梯度Gradients

在Pytorch中,反向传播是依靠.backward()实现的

out.backward()

完整代码

import torch
x1=torch.ones(3,3)
print(x1)

x=torch.ones(2,2,requires_grad=True)
print(x)

#在具有requires_grad=True的Tensor上进行加法操作
y=x+2
print(y)
print(x.grad_fn)
print(y.grad_fn)

z=y*y*3
out=z.mean()
print(z)
print(out)
out.backward()
print(x.grad)

关于自动求导的属性设置

可以通过设置.requires_grad=True来执行自动求导,也可以通过代码块的限制来停止自动求导

print(x.requires_grad)
print((x**2).requires_grad)
with torch.no_grad():
    print((x**2).requires_grad)
#输出
True
True
False

可以通过.detach()获得一个新的Tensor,拥有相同的内容但不需要自动求导

print(x.requires_grad)
y=x.detach()
print(y.requires_grad)
print(x)
print(y)
print(x.eq(y).all())
#输出
True
False
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
tensor([[1., 1.],
        [1., 1.]])
tensor(True)

避免在代码里大量使用detach

猜你喜欢

转载自blog.csdn.net/kz_java/article/details/121406874