PyTorch .detach(), .detach_() and .data

When training a network, we may want to keep some parameters fixed and adjust only the others, or train only one branch of the network without letting its gradients affect the main network. In such cases, we can use the detach() function to cut off backpropagation along certain branches.
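A minimal sketch of the "train a branch without touching the main network" case. The names backbone and branch, the layer sizes, and the random input are hypothetical, chosen only for illustration; the point is that the branch consumes a detached copy of the backbone's features.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model split: a "main" backbone and an auxiliary "branch" head.
backbone = nn.Linear(4, 3)
branch = nn.Linear(3, 1)

x = torch.randn(2, 4)
features = backbone(x)

# The branch only sees detached features, so its loss cannot
# send any gradient back into the backbone's parameters.
branch_loss = branch(features.detach()).sum()
branch_loss.backward()

print(backbone.weight.grad)            # None: backbone untouched by branch_loss
print(branch.weight.grad is not None)  # True: the branch still gets gradients
```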


1. detach()

Returns a new Variable that is detached from the current computation graph but still points to the same storage as the original Variable. The only difference is that its requires_grad is False: the new Variable never needs its gradient computed and has no grad. (Since PyTorch 0.4, Variable has been merged into Tensor, so everything below applies to tensors as well.)

Even if its requires_grad is later set back to True, it still receives no gradient from the original graph: it becomes a new leaf, and backpropagation through it stops there.
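A minimal sketch of what switching requires_grad back on actually does: the detached tensor becomes a fresh leaf, so it can accumulate a gradient of its own, but nothing flows back through it into the original graph (a.grad stays None).

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
out = a.sigmoid()

c = out.detach()
c.requires_grad_(True)   # c is now a new leaf tensor of its own

(c * 2).sum().backward()

print(c.grad)  # tensor([2., 2., 2.]): c gets a gradient as a new leaf
print(a.grad)  # None: nothing flows back through c into the original graph
```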

If we keep computing with this new Variable, then during backpropagation the gradient stops at the detached Variable and cannot propagate any further back.
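A small sketch of the gradient stopping at a detached Variable: z depends on x through two paths, but only the tracked path contributes to x.grad.

```python
import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x * 2

# Two paths from x to z: through y (tracked) and through y.detach() (cut).
z = (y + 3 * y.detach()).sum()
z.backward()

# dz/dx = 2 from the y term; the 3 * y.detach() term contributes nothing.
print(x.grad)  # tensor([2., 2., 2.])
```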

The source code (from an older PyTorch release, when Variable was still a separate class) is:

def detach(self):
    """Returns a new Variable, detached from the current graph.
    Result will never require gradient. If the input is volatile, the output
    will be volatile too.
    .. note::
      Returned Variable uses the same data tensor, as the original one, and
      in-place modifications on either of them will be seen, and may trigger
      errors in correctness checks.
    """
    result = NoGrad()(self)  # this is needed, because it merges version counters
    result._grad_fn = None
    return result

As can be seen, the function does the following:

  • Set grad_fn to None
  • Set the Variable's requires_grad to False
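Both effects, plus the shared storage, can be checked directly on a modern tensor. This sketch uses data_ptr() to compare the underlying memory addresses.

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
out = a.sigmoid()
c = out.detach()

print(out.grad_fn is not None, out.requires_grad)  # True True
print(c.grad_fn, c.requires_grad)                   # None False
# detach() does not copy: both tensors point at the same memory.
print(c.data_ptr() == out.data_ptr())               # True
```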

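If the input has volatile=True (i.e. no history is recorded; this was used to speed up computation when only the result is needed and no parameters are updated), the returned Variable also has volatile=True. (volatile is deprecated; use torch.no_grad() instead.)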
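For reference, a minimal sketch of the torch.no_grad() replacement for volatile=True: operations inside the block are not recorded in the graph.

```python
import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)

# Nothing computed inside this block is recorded for backpropagation.
with torch.no_grad():
    y = x * 2

print(y.requires_grad, y.grad_fn)  # False None
```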

Note:

The returned Variable shares the same data tensor as the original Variable. In-place modifications on either one are reflected in both (because they share the data tensor), which may cause errors when backward() is called on them.
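A sketch of the sharing in action. It reads `_version`, which is an internal, undocumented counter that in-place operations increment; treat it as an implementation detail that may change between releases. The detached tensor and the original share both the data and this counter, which is how autograd later notices the modification.

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
out = a.sigmoid()
c = out.detach()

c.add_(1.0)  # in-place change through the detached tensor

print(out)  # out's values changed too, because the data is shared
# _version (internal) is bumped by in-place ops; c and out share one counter.
print(c._version, out._version)
```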

For example, a normal case without detach():

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()

out.sum().backward()
print(a.grad)

Output:

None
tensor([0.1966, 0.1050, 0.0452])

When detach() is used but the detached Variable is not modified, backward() is unaffected:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# apply detach(); c's requires_grad is False
c = out.detach()
print(c)

# c has not been modified, so backward() is unaffected
out.sum().backward()
print(a.grad)

Output:

None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])

As can be seen, the difference between c and out is that c has no grad_fn and does not require gradients, while out does.


If c is used for sum() and backward() instead, an error is raised:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# apply detach(); c's requires_grad is False
c = out.detach()
print(c)

# backpropagate from the newly created Variable
c.sum().backward()
print(a.grad)

Output:

None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
Traceback (most recent call last):
  File "test.py", line 13, in <module>
    c.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

If c is now modified in place, the change is tracked by autograd (c and out share the same data and version counter), so calling backward() on out.sum() raises an error: the backward of sigmoid needs its output out, which has been changed, so the resulting gradient would be wrong:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# apply detach(); c's requires_grad is False
c = out.detach()
print(c)
c.zero_()  # modify c with an in-place function

# the change to c also changes the value of out
print(c)
print(out)

# c has been modified, which affects backward(); backward() now raises an error
out.sum().backward()
print(a.grad)

Output:

None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    out.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
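If the detached values really must be modified, one common option (a sketch, not something the original post prescribes) is to clone them first. clone() gives c its own memory and its own version counter, so the output that SigmoidBackward needs is left untouched and backward() succeeds.

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
out = a.sigmoid()

# detach() then clone(): c no longer shares memory with out.
c = out.detach().clone()
c.zero_()

out.sum().backward()
print(c)       # tensor([0., 0., 0.])
print(a.grad)  # tensor([0.1966, 0.1050, 0.0452])
```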

2. .data

If .data is used in the above example instead, the behavior is different:

The difference is that modifications made through .data are not tracked by autograd, so backward() raises no error and silently returns an incorrect gradient.

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)


c = out.data
print(c)
c.zero_()  # modify c with an in-place function

# the change to c also changes the value of out
print(c)
print(out)

# the difference: changes made through .data are not tracked by autograd, so backward() raises no error but returns a wrong gradient
out.sum().backward()
print(a.grad)

Output:

None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])
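The contrast between the two previous examples can be seen through the internal `_version` counter (an undocumented implementation detail, used here only for illustration). A tensor obtained with detach() shares the original's counter, while one obtained with .data gets a fresh counter, so autograd never sees the modification.

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)

out_d = a.sigmoid()
c_d = out_d.detach()
c_d.zero_()

out_data = a.sigmoid()
c_data = out_data.data
c_data.zero_()

# detach(): the in-place change bumps out_d's counter, so backward() can detect it.
print(out_d._version)     # > 0
# .data: the change lands on c_data's own counter; out_data looks unmodified.
print(out_data._version)  # 0
```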

The principle behind the above is the in-place correctness check:

Every Variable keeps a record of the in-place operations applied to it. If PyTorch detects that a Variable was saved in a Function for use in backward, but was later modified by an in-place operation, it raises an error during backward. This mechanism guarantees that if you use in-place operations and no error is raised during backward, the computed gradients are correct.
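The check can be triggered without detach() at all. In this sketch, exp() saves its own output for the backward pass; modifying that output in place makes backward() fail instead of returning a wrong gradient.

```python
import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x.exp()    # the backward of exp() reuses its output y
y.add_(1.0)    # in-place change to a tensor saved for backward

try:
    y.sum().backward()
    raised = False
except RuntimeError as e:
    raised = True
    print("RuntimeError:", str(e).splitlines()[0])

print(raised)  # True
```

Writing y = x.exp() + 1 out of place instead avoids the error, because the saved output is never overwritten.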

The following result, however, is correct: here the in-place change hits the output of sum(), while the intermediate value a.sigmoid() is untouched. Since the backward of sum() does not need its own output, the change has no effect on the gradient:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid().sum()  # applying sum() here, rather than right before backward(), gives the correct result
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

# sum() is no longer applied here
out.backward()
print(a.grad)

Output:

None
tensor(2.5644, grad_fn=<SumBackward0>)
tensor(2.5644)
tensor(0.)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])
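The same reasoning holds with detach(), as a sketch shows: zeroing sum()'s output through a detached view does bump the shared version counter, but the backward of sum() saved nothing that was modified, and the sigmoid output it depends on is a separate tensor, so no error is raised and the gradient is still correct.

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
out = a.sigmoid().sum()

c = out.detach()
c.zero_()   # shares memory with out, so out becomes 0 as well

# The backward of sum() does not need its own output, so this is safe.
out.backward()
print(out.item())  # 0.0
print(a.grad)      # tensor([0.1966, 0.1050, 0.0452])
```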

3. detach_()

Detaches a Variable from the graph that created it, making it a leaf Variable.

Suppose the variables are originally related as x -> m -> y, with x as the leaf. Calling .detach_() on m actually performs two operations:

  • m's grad_fn is set to None, so m is no longer connected to the previous node x. The relationship becomes x, m -> y, and m is now a leaf node.
  • m's requires_grad is set to False, so no gradient is computed for m when backward() is called on y.
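A minimal sketch of the two operations above. Here y is built only after detach_() has been called on m; its remaining path to x is a direct "+ x" term added for illustration, so x.grad shows that the path through m is cut.

```python
import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
m = x * 2
m.detach_()   # in place: grad_fn -> None, requires_grad -> False

print(m.grad_fn, m.requires_grad, m.is_leaf)  # None False True

# Only the direct "+ x" term still connects y to x.
y = (m * 3 + x).sum()
y.backward()
print(x.grad)  # tensor([1., 1., 1.])
```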

detach() and detach_() are therefore very similar. The difference is that detach_() modifies the Variable itself in place, while detach() returns a new Variable.

For example, in x -> m -> y, if detach() is called on m, the original computation graph is left intact, so you can still work with it if you change your mind later.

But if detach_() is called, the original computation graph itself has changed, and there is no going back.
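A short sketch of that difference: after detach(), m is still part of the graph and can be backpropagated through; after detach_(), m itself has lost its grad_fn.

```python
import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
m = x * 2

d = m.detach()                 # new tensor; m itself is untouched
print(d.requires_grad)         # False
print(m.grad_fn is not None)   # True: m is still attached to the graph
m.sum().backward()             # so we can still backpropagate through m
print(x.grad)                  # tensor([2., 2., 2.])

m.detach_()                    # now m itself is cut from the graph
print(m.grad_fn, m.requires_grad)  # None False
```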

Reposted from: https://blog.csdn.net/weixin_34363171/article/details/94236818


Original post: blog.csdn.net/Answer3664/article/details/104314030