Pytorch-detach()用法

目的:

神经网络的训练有时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整。或者训练部分分支网络,并不让其梯度对主网络的梯度造成影响.这时候我们就需要使用detach()函数来切断一些分支的反向传播.

1 tensor.detach()

返回一个新的tensor,从当前计算图中分离下来。但是仍指向原变量的存放位置,不同之处只是requirse_grad为false.得到的这个tensir永远不需要计算器梯度,不具有grad.

即使之后重新将它的requires_grad置为true,它也不会具有梯度grad.这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播.

注意:

使用detach返回的tensor和原始的tensor共同一个内存,即一个修改另一个也会跟着改变

比如正常的例子是:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a)
print(a.grad)
out = a.sigmoid()

out.sum().backward()
print(a.grad)

输出

tensor([1., 2., 3.], requires_grad=True)
None
tensor([0.1966, 0.1050, 0.0452])

1.1 当使用detach()分离tensor但是没有更改这个tensor时,并不会影响backward():

import torch
 
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
 
#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
 
#这时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)
 
'''返回:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])
'''

参考

1.pytorch的两个函数 .detach() .detach_() 的作用和区别

猜你喜欢

转载自blog.csdn.net/qq_31244453/article/details/112473947