pytorch0.4使用注意

1.梯度
1.Variable()中,requires_grad=Fasle时不需要更新梯度, 适用于冻结某些层的梯度;
2.volatile=True相当于requires_grad=False,适用于测试阶段,不需要反向传播。在torch>=0.4中,这个现在已经取消了,使用with torch.no_grad()或者torch.set_grad_enable(grad_mode)来替代:

with torch.no_grad():
  test()
>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False
>>>
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

2.tensor与Variable
Tensor在0.4中,现在默认requires_grad=False的Variable了,相当于(tensor 等价于Variable(Tensor,requires_grad=Fasle)),
torch.Tensor和torch.autograd.Variable现在其实是同一个类! 没有本质的区别! 所以也就是说, 现在已经没有纯粹的Tensor了, 是个Tensor, 它就支持自动求导! 你现在要不要给Tensor包一下Variable都没有任何意义了
下面是0.4中一些新建tensor的方法

#0.4中建立一个tensor:
>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344,  0.8562, -1.2758],
        [ 0.8414,  1.7962,  1.0589],
        [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad  # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

3.requires_grad 已经是Tensor的一个属性了
举个例子:

>>> x = torch.ones(1)
>>> x.requires_grad #默认是False
False
这里也说明了tensor就是一个requires_grad=False的Variable

4.不要随便用.data
在torch0.3中,Variable分为tensor和grad两项,通过.data取出Variable中的Tensor,torch0.4变了.
torch0.4中,.data返回的是一个tensor,但是现在这个tensor是一个有requires_grad(可以自动求导)的tensor,而且现在.data取出的tensor和之前的Variable是内存共享,所以不安全.

y = x.data # x需要进行autograd的
# y和x是共享内存的,但是这里y已经不需要grad了, 
# 所以会导致本来需要计算梯度的x也没有梯度可以计算.从而x不会得到更新!

为了解决上面的风险:所以, 推荐用x.detach(), 这个仍旧是共享内存的, 也是使得y的requires_grad为False, 但是,如果x需要求导, 仍旧是可以自动求导的!

y = x.datach() # x需要进行autograd的
y和x也是共享内存,并且y的requires_grad为False,但是,如果x需要求导, 仍旧是可以自动求导的!

5..item()
以前取tensor的值用.data,现在用.item()
比如为了显示loss到命令行:以前了累加loss(为了看loss的大小)一般是用total_loss+=loss.data[0] , 比较诡异的是, 为啥是.data[0]? 这是因为, 这是因为loss是一个Variable, 所以以后累加loss, 用loss.item().

6.弃用 volatile(同最开始的梯度解释)

参考:https://www.itency.com/topic/show.do?id=494122

猜你喜欢

转载自blog.csdn.net/CV_YOU/article/details/84591987