pytorch强大的自动求导功能

import torch

x = torch.tensor(2.,requires_grad=True)  #requires_grad=True不能丢,因为默认是False,第一个参数一定得是float形式才能求导
w = torch.tensor(1.,requires_grad=True)
b = torch.tensor(3.,requires_grad=True)

y = w*x + b

y.backward()    #输出为torch.Size([]),所以标量是0维
print(x.shape)
print(w.grad)
print(x.grad)
print(b.grad)
print(y.grad)

输出如下,对向量和矩阵也能自动求导惹

torch.Size([])
tensor(2.)
tensor(1.)
tensor(1.)
None

猜你喜欢

转载自www.cnblogs.com/loubin/p/12738529.html