pytorch踩坑日记

昨天使用pytorch写一个程序,程序写完之后却一直不能正确运行,今天定位到了代码的问题所在:
我的代码其中有一处逻辑是这样的:

……
get a # 这里的a就是我想反向求导更新的参数
b=torch.nonzero(a)  # 得到a里面所有不为0的下标

for i,j in b:
	feature_i=feature[i]
	feature[j]=feature[j]
	……
	get c  # 通过b中的元素下标得到一个c
	
c->loss  # c经过一系列操作得到了loss

之后loss反向求导发现a的梯度一直是0,这是因为我只用到了a中不为0元素的下标,根本就没有用到a这个矩阵,此时的计算图应该是这样的:

b->c->loss

因为a->b根本没有产生梯度关系,所以计算图肯定不能反向传导到a。

以后写代码要思考计算图的建立

Guess you like

Origin blog.csdn.net/cobracanary/article/details/121053550