问题解决:Pytorch :Trying to backward through the graph a second time, but the buffers。。

最近在学习Pytorch,刚用Pytorch重写了之前用Tensorlfow写的论文代码。
首次运行就碰到了一个bug:
Pytorch - RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
刚开始按照这个错误提示,设置loss.backward(retain_graph=True),虽然解决了这个问题,但是随着训练的继续,报错OOM。很尴尬。。。
查了stackoverflow上的方法,最终解决了问题。

我原来的代码是:

 for side in outputs:
     loss += Loss(side, label)

 loss.backward(retain_graph=True)

很显然,一旦调用loss.backward(), 就相当于调用了多次的Loss(side, label)的.backward()方法,而Pytorch的机制是每次调用.backward()都会free掉所有buffers,所以它提示,让retain_graph。然而当retain后,buffers就不会被free了,所以会OOM。
最后的解决办法就是, 分开写:

side0 = Loss(output[0], label)
side1 = Loss(output[1], label)
side2 = Loss(output[2], label)
side3 = Loss(output[3], label)
side4 = Loss(output[4], label)
side5 = Loss(output[5], label)
loss = side0 + side1 + side2 + side3 + side4 + side5

作者:Mundane_World
来源:CSDN
原文:https://blog.csdn.net/Mundane_World/article/details/81038274
版权声明:本文为博主原创文章,转载请附上博文链接!

猜你喜欢

转载自blog.csdn.net/Harpoon_fly/article/details/84372595