【PyTorch 计算图】requires_grad=True的leaf variable及其设计逻辑

内容简介

最近在写pytorch代码的时候遇到了这种情况:

  1. 一个RNN层,我们这里把这个层的参数记做 p a r a m r n n param_{rnn} paramrnn
  2. 一个输入X,得到输出 output, h_n = RNN(X)。这里output存的是每个节点的隐层,h_n是最后一层的隐层,具体去看pytorch的docs
  3. 有时候我们需要的是一个batch内不同seq_lens位置的的隐层输出,而非直接拿padding后的最后的隐层h_n,这个时候我们通常会进行这样的操作:
def forward(self, X):
	output, h_n = RNN(X)
	real_output = torch.zeros((output.shape[0], self.hidden_size))
	for idx, crnt_len in enumerate(seq_lens):
		real_output[idx] = output[idx, crnt_len-1].clone()

我们先做了一个空的tensor容器,每次往里面放对应位置的隐层向量。这个代码结构我们在很多涉及RNN的模型中看到过,但是细细想来实际上有一些问题值得我们深究:

  • 这个real_output从头至尾都没有被主动放到cuda上,为什么后面它在跟cuda的tensor交互的时候没有报错?进一步,如果你尝试在上面代码的第三行后面加一句.to(device),pytorch会在你执行for循环内容的时候给你报错:RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

主体

什么是leaf variable

在上面的报错中,我们看到有一个可能看起来比较陌生的词leaf variable,具体要知道是怎么回事我们需要先看看一个概念叫计算图(computaton gragh),这是pytorch通过向后传播进行自动微分的一个重要结构。如果想要具体了解,可以看看pytorch之前发过的一个blog

具体而言,计算图的节点可以分为几类:parameters,输入(X),中间的计算结果,以及最终的loss。对于这几种节点,我们可以发现,如果我们以loss节点为根,把整张计算图“拎起来”,会发现又一些节点处于最高的位置——没有其他节点在它上面。如果我们把loss节点看作树根,这些“最高的节点”其实就是树叶的位置。

进一步我们对这些“树叶”节点进行讨论,因为中间节点有其他节点在它上面,又不是loss节点,它从属的类只有两种可能:

  1. 需要更新的参数,即parameters
  2. 输入,即X,不需要grad来更新

对于pytorch这个结构而言,它关心的从来都是自己结构内部的参数,由使用者定义的,各不相同的输入不在他的重心范围内,所以他只把需要跟新的parameters叫做“树叶节点”——即leaf variable


换言之,我们用代码的语言来说,就是requires_grad=True的那些tensor。


这里我们需要进行说明,虽然有时候计算图不一定是数学意义下严格的树结构,但是这样有助于我们进行阐述。

为什么不用.to(device)

如果自己尝试过会发现,在进行部分的in_place操作后,整个被赋值的tensor都会被自动迁移到数据源tensor的device上。

为什么会报这个错

我希望这里从模型设计的底层逻辑上进行阐述。

这个leaf variable正如我们先前所说,本质上只是一个attribute不太一样的tensor,但是我们还需要考虑到它的意义内涵。这个requires_grad=True实际上的表意是:我这个tensor代表的是模型的参数,而模型的参数在设计上来说,在被初始化之后,只能通过梯度更新来改变tensor内部的值。

我认为这才是pytorch设置这个报错的原因,当我们尝试通过赋值对leaf variable内部进行值进行修改的时候,我们实质上破坏的是模型的设计逻辑,我们能且只能通过参数更新来改变leaf variable内部的数值。

怎么改

重新审视自己模型设计的逻辑,看这块的tensor是不是实际上只是一个中间的过渡节点,实际上并不需要参数更新,进一步不需要requires_grad=True。

而如果确实是parameters,那么你需要重新审视为什么你觉得这里需要进行赋值,如果是初始化,我建议可以先进行赋值,再把requires_grad置为True。

猜你喜欢

转载自blog.csdn.net/Petersburg/article/details/127604500