pytorch 中的variable函数

torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现(tensor变成variable之后才能进行反向传播求梯度?用变量.backward()进行反向传播之后,var.grad中保存了var的梯度)

x = Variable(tensor, requires_grad = True)

Varibale包含三个属性:

  • data:存储了Tensor,是本体的数据
  • grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
  • grad_fn:指向Function对象,用于反向传播的梯度计算之用

用法:

  1. import torch
  2. from torch.autograd import Variable
  3. x = Variable(torch.one( 2,2), requires_grad = True)
  4. print(x) #其实查询的是x.data,是个tenso

猜你喜欢

转载自www.cnblogs.com/lyp1010/p/12146519.html