深度学习总结:GAN,3种方式实现fixedGtrainD,fixedDtrainG, retain, detach

retain和detach

pytorch有两个功能:retain和detach:
retain:意思是保持原来graph,可以还在原图上进行forward pass,下次计算backward还是在原图上计算;
detach:阻断,意思是backward pass到这儿就停止了

这两个东西可以用于实现fixedGtrainD,fixedDtrainG。

先更新D,再更新G,这个也是GAN论文的实现方式

先forward整个网络;
再backward整个网络的梯度但是只更新D的参数(相当于G的部分的梯度白算了),这时还需要retain graph一下;
用fake data forward一下D,再backward整个网络的梯度但是只更新G的参数;
至此完成了一轮G和D的对抗。
在这里插入图片描述

先更新G,再更新D,实际上他两谁先谁后都一样,都是相互对抗:

先forward整个网络;
再backward整个网络的梯度但是只更新G的参数(相当于D的部分的梯度只是用来传递G的梯度);
先forward整个网络,这时还需要detach一下G;
再backward到G就停止了,更新D的参数;
至此完成了一轮G和D的对抗。
在这里插入图片描述

第三种是第一种的改进, 先更新D,再更新G,计算最少,还没见到别人实现,估计知名框架这么实现,每具体检查过:

先forward整个网络,同时detach一下G,retain一下graph;
再backward到G停止了,但是只更新D的参数,同时undetach一下G,retain一下graph;用fake data forward一下D;
再backward整个网络,更新G的参数(相当于D的部分的梯度只是用来传递G的梯度);
至此完成了一轮G和D的对抗。
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_40759186/article/details/87534954
今日推荐