pytorch怎么使用任一Variable生成另外一个叶节点Variable

时间:2018/2/7

一般情况下,用一个完整的网络就可以了。但是像我现在要做一个network in network,大网络里要附加一个小网络,而且还想单独训练这个小网络。使用pytorch实现这个想法的时候问题就来了:pytorch只对Variable叶节点有显式的梯度计算,所以任何其他的操作clone等都不能计算梯度的。而又不想训练这个单独的小网络的时候,将梯度传递到主网络里面去,所以有以下两种方法。

方法一:

使用Variable.detach() detach的官网介绍

比如模拟了一个网络的代码如下:

import torch
import numpy as np
from torch.autograd import Variable
import torch.nn as nn

a = np.arange(24).reshape(1, 2, 3, 4).astype(np.float32)
b = Variable(torch.from_numpy(a.astype(np.float32)))
d = b.clone().cuda() + 1
b.requires_grad = True



x1 = nn.Conv2d(2, 2, 1).cuda()
x2 = nn.ReLU().cuda()
l = nn.MSELoss()

x = x1(b.cuda())
c = x.detach()
c.requires_grad = True
xc = x1(c)

ls = l(xc, d)
ls.backward()

print b.grad
print '*' * 10
print c.grad

方法二:

使用clone()

这样得到的变量的确不是叶节点,但是小网络里面用到的权值和偏置项是叶节点,pytorch还是能够将接下来的计算用于求算权重的梯度.

哈哈,折腾了将近一小时,不如脑袋短路一秒钟。
However,使用detach可以将graph结构就此断路,不会继续前传,clone则是不行的。

猜你喜欢

转载自blog.csdn.net/daniaokuye/article/details/79282657