Pytorch 常见报错 RuntimeError: Trying to backward through the graph a second time

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

model = RNN()
hn = torch.zeros(1,seq_len,hidden_num)
epochs = 250
clip_value = 100
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

for epoch in range(epochs):
    accu,num = 0.0,0
    for x,y in data_collect(corpus_indice,batch_size,seq_len):
		
		#-------------------------------------------------------------#
		# 这里添加上一句话即可
		hn.detach_()
		#-------------------------------------------------------------#
        output,hn = model(x,hn)
        y = y.transpose(1,0).contiguous().view(-1)
        ls = loss(output.view(-1,vocab_len),y)
        
        optimizer.zero_grad() 
        ls.backward()
        
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        
        optimizer.step()
        accu += ls.item() * y.shape[0]
        num += y.shape[0]
    if epoch%50 == 0:
        print("现在是第{}次epoch,loss的值为{}".format(epoch,math.exp(accu/num)))
print("完成")
  • 产生问题的原因(我个人的理解是这样的):
    每一次在读取并计算完一个 b a t c h batch batch 的数据之后会有一个 h t h_t ht 连接到计算图中,这个 h t h_t ht 参与反向传播求梯度,梯度在求完结果之后就释放掉了,当到下一个也就是 h t + 1 h_{t+1} ht+1 的时候,在计算梯度时还会经过 h t h_t ht(因为这两个在计算图中连着),但是 h t h_t ht 的相关信息已经被释放了,所以会产生报错。
  • 解决方法:
    • 第一种解决方案:
      hidden.detach_() 在每次读取一个 b a t c h batch batch 的数据之后,开始训练之前都要将隐藏层从计算图中 d e t a c h detach detach 出来。
    • 第二种解决方案:
      loss.backward() 替换成 loss.backward(retain_graph=True),这种方法相当于将计算图的全部内容都存储下来,中间不进行清零操作。很明显,一方面占内存比较多,另一方面计算起来会很慢,所以不推荐。

猜你喜欢

转载自blog.csdn.net/weixin_44618906/article/details/107435076