Tabla de contenido
1.4 Comprobación de corrección in situ
Cuando estamos entrenando la red, es posible que deseemos mantener algunos de los parámetros de la red sin cambios y solo ajustar algunos de los parámetros, o solo entrenar parte de la red secundaria y no dejar que su gradiente afecte el gradiente de la red principal. esta vez, lo haremos La función detach() debe usarse para cortar la propagación hacia atrás de algunas ramas.
1. .separar()
Devuelve uno nuevo Variable
, que está separado del gráfico de cálculo actual, pero aún apunta a la ubicación de almacenamiento de la variable original. La única diferencia es que require_grad es falso, y el obtenido nunca necesita calcular su Variable
gradiente y no tiene grad. .
Incluso si require_grad se vuelve a establecer en verdadero más tarde, no tendrá una graduación de gradiente.
De esta manera, cuando continuamos usando esta nueva Variable进行计算,后面当我们进行
retropropagación, la llamada a detach() Variable
se detendrá y no podremos continuar propagando hacia adelante.
El código fuente es:
def detach(self):
"""Returns a new Variable, detached from the current graph.
Result will never require gradient. If the input is volatile, the output
will be volatile too.
.. note::
Returned Variable uses the same data tensor, as the original one, and
in-place modifications on either of them will be seen, and may trigger
errors in correctness checks.
"""
result = NoGrad()(self) # this is needed, because it merges version counters
result._grad_fn = None return result
可见函数进行的操作有:
- Establecer grad_fn en Ninguno
将Variable
derequires_grad设置为False
Si se ingresa volatile=True(即不需要保存记录,当只需要结果而不需要更新参数时,这么设置可以加快运算速度)
, entonces lo que se devuelve Variable
. ( volatile
=True
en desuso)volatile
注意:
El devuelto comparte lo mismo que Variable
el original . Las modificaciones se reflejarán en ambos (ya que son compartidos ), lo que puede causar errores al intentar llamarlos hacia atrás().Variable
data tensor
in-place函数
Variable
data tensor
Consulte el siguiente ejemplo para comprender.
1.1 Ejemplo 1
Por ejemplo, un ejemplo normal es:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
out.sum().backward()
print(a.grad)
devolver:
(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.1966, 0.1050, 0.0452])
Backward() no se ve afectado cuando se usa detach() pero no se realizan cambios:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
#这时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)
devolver:
(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])
Se puede ver que la diferencia entre c y out es que c no tiene gradiente y out tiene gradiente.
1.2 Ejemplo 2
Si c se usa aquí para la operación sum() y hacia atrás(), se informará un error:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
#使用新生成的Variable进行反向传播
c.sum().backward()
print(a.grad)
devolver:
(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
Traceback (most recent call last):
File "test.py", line 13, in <module>
c.sum().backward()
File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
1.3 Ejemplo 3
Si se realiza un cambio en c en este momento, este cambio será rastreado por autograd, y también se informará de un error al realizar la función de retroceso() en out.sum(), porque el gradiente obtenido al realizar la función de retroceso() en el valor en este momento está mal:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
c.zero_() #使用in place函数对其进行修改
#会发现c的修改同时会影响out的值
print(c)
print(out)
#这时候对c进行更改,所以会影响backward(),这时候就不能进行backward(),会报错
out.sum().backward()
print(a.grad)
devolver:
(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
File "test.py", line 16, in <module>
out.sum().backward()
File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
El principio de la realización del contenido anterior es:
1.4 Comprobación de corrección in situ
Todo Variable
será grabado y utilizado en ellos in-place operations
. Si pytorch
se detecta que variable
un Function
archivo se ha guardado para su uso backward
, pero se ha in-place operations
modificado posteriormente. Cuando esto sucede, en backward
, pytorch
se informará un error. Este mecanismo asegura que si lo usa in-place operations
, pero backward
no se reporta ningún error durante el proceso, entonces el cálculo del gradiente es correcto.
dos, .datos
Si la operación anterior usa .data, el efecto será diferente:
La diferencia aquí es que la modificación de .data no será rastreada por autograd, por lo que no informará de un error al hacer retroceder(), y obtendrá un valor de retroceso incorrecto .
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改
#会发现c的修改同时也会影响out的值
print(c)
print(out)
#这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值
out.sum().backward()
print(a.grad)
devolver:
(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])
El siguiente resultado es correcto porque el resultado de sum() cambia y el valor intermedio a.sigmoid() no se ve afectado, por lo que no tiene efecto en el gradiente:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid().sum() #但是如果sum写在这里,而不是写在backward()前,得到的结果是正确的
print(out)
c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改
#会发现c的修改同时也会影响out的值
print(c)
print(out)
#没有写在这里
out.backward()
print(a.grad)
devolver:
(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor(2.5644, grad_fn=<SumBackward0>)
tensor(2.5644)
tensor(0.)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])
3. .separar_()
separe a Variable
del gráfico que lo creó y configúrelo como una hojavariable
De hecho, es equivalente a que la relación entre variables es x -> m -> y, la variable hoja aquí es x, pero en este momento la operación .detach_() se realiza en m, que en realidad son dos operaciones:
- Establezca el valor de m's grad_fn en Ninguno, de modo que m ya no se asocie con el nodo anterior x, y la relación aquí se convertirá en x, m -> y, y m se convertirá en un nodo hoja en este momento
- Luego, require_grad de m se establecerá en False, de modo que no se buscará el gradiente de m cuando se realice back() en y
Mirándolo de esta manera, detach() es en realidad muy similar a detach_() La diferencia entre los dos es que detach_() es un cambio en sí mismo, y detach() genera una nueva variable.
Por ejemplo, si detach() se realiza en m en x -> m -> y, si luego se arrepiente, aún desea operar en el gráfico de cálculo original. Pero si se ejecuta detach_(), entonces el gráfico de cálculo original también ha cambiado y no puede retractarse de su palabra.