从零学习Pytorch 什么是Variable?

Tensor是Pytorch的一个完美组件(可以生成高维数组),但是要构建神经网络还是远远不够的,我们需要能够计算图的Tensor,那就是Variable。Variable是对Tensor的一个封装,操作和Tensor是一样的,但是每个Variable都有三个属性,Varibale的Tensor本身的.data,对应Tensor的梯度.grad,以及这个Variable是通过什么方式得到的.grad_fn。

# 通过一下方式导入Variable
from torch.autograd import Variable
import torch
x_tensor = torch.randn(2,5)
y_tensor = torch.randn(2,5)
#将tensor转换成Variable
x = Variable(x_tensor) #Varibale 默认时不要求梯度的,如果要求梯度,需要说明
y = Variable(y_tensor,requires_grad=True)
print(x.grad)
print(y.grad)
print('-'*20)
z = torch.sum(x + y)
print(z)
print('data',z.data)
print('grad',z.grad)
print('grad_fn',z.grad_fn)
z.backward()

在这里插入图片描述

发布了78 篇原创文章 · 获赞 14 · 访问量 9720

猜你喜欢

转载自blog.csdn.net/qq_34107425/article/details/104128134
今日推荐