Explicación detallada de .detach(), .data y .detach_()

Tabla de contenido

1. .separar()

1.1 Ejemplo 1

1.2 Ejemplo 2

1.3 Ejemplo 3

1.4 Comprobación de corrección in situ

dos, .datos

3. .separar_()


Cuando estamos entrenando la red, es posible que deseemos mantener algunos de los parámetros de la red sin cambios y solo ajustar algunos de los parámetros, o solo entrenar parte de la red secundaria y no dejar que su gradiente afecte el gradiente de la red principal. esta vez, lo haremos La función detach() debe usarse para cortar la propagación hacia atrás de algunas ramas.

1. .separar()

Devuelve uno nuevo Variable, que está separado del gráfico de cálculo actual, pero aún apunta a la ubicación de almacenamiento de la variable original. La única diferencia es que require_grad es falso, y el obtenido nunca necesita calcular su Variablegradiente y no tiene grad. .

Incluso si require_grad se vuelve a establecer en verdadero más tarde, no tendrá una graduación de gradiente.

De esta manera, cuando continuamos usando esta nueva Variable进行计算,后面当我们进行retropropagación, la llamada a detach() Variablese detendrá y no podremos continuar propagando hacia adelante.

El código fuente es:

def detach(self):
        """Returns a new Variable, detached from the current graph.
        Result will never require gradient. If the input is volatile, the output
        will be volatile too.
        .. note::
          Returned Variable uses the same data tensor, as the original one, and
          in-place modifications on either of them will be seen, and may trigger
          errors in correctness checks.
        """
        result = NoGrad()(self)  # this is needed, because it merges version counters
        result._grad_fn = None     return result

可见函数进行的操作有:

  • Establecer grad_fn en Ninguno
  • 将Variablederequires_grad设置为False

Si se ingresa  volatile=True(即不需要保存记录,当只需要结果而不需要更新参数时,这么设置可以加快运算速度), entonces lo que se devuelve Variable volatile=True. ( volatileen desuso)

注意:

El devuelto comparte lo mismo que Variableel original . Las modificaciones se reflejarán en ambos (ya que son compartidos ), lo que puede causar errores al intentar llamarlos hacia atrás().Variabledata tensorin-place函数Variabledata tensor

Consulte el siguiente ejemplo para comprender.

1.1 Ejemplo 1

Por ejemplo, un ejemplo normal es:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()

out.sum().backward()
print(a.grad)

devolver:

(deeplearning) userdeMBP:pytorch user$ python test.py 
None
tensor([0.1966, 0.1050, 0.0452])

Backward() no se ve afectado cuando se usa detach() pero no se realizan cambios:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)

#这时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)

devolver:

(deeplearning) userdeMBP:pytorch user$ python test.py 
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])

Se puede ver que la diferencia entre c y out es que c no tiene gradiente y out tiene gradiente.

1.2 Ejemplo 2

Si c se usa aquí para la operación sum() y hacia atrás(), se informará un error:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)

#使用新生成的Variable进行反向传播
c.sum().backward()
print(a.grad)

devolver:

(deeplearning) userdeMBP:pytorch user$ python test.py 
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
Traceback (most recent call last):
  File "test.py", line 13, in <module>
    c.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

1.3 Ejemplo 3

Si se realiza un cambio en c en este momento, este cambio será rastreado por autograd, y también se informará de un error al realizar la función de retroceso() en out.sum(), porque el gradiente obtenido al realizar la función de retroceso() en el valor en este momento está mal:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时会影响out的值
print(c)
print(out)

#这时候对c进行更改,所以会影响backward(),这时候就不能进行backward(),会报错
out.sum().backward()
print(a.grad)

devolver:

(deeplearning) userdeMBP:pytorch user$ python test.py 
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    out.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

 El principio de la realización del contenido anterior es:

1.4 Comprobación de corrección in situ

Todo Variableserá grabado y utilizado en ellos  in-place operations. Si pytorchse detecta que variableun Functionarchivo se ha guardado para su uso backward, pero se ha in-place operationsmodificado posteriormente. Cuando esto sucede, en backward, pytorchse informará un error. Este mecanismo asegura que si lo usa in-place operations, pero backwardno se reporta ningún error durante el proceso, entonces el cálculo del gradiente es correcto.

dos, .datos

Si la operación anterior usa .data, el efecto será diferente:

La diferencia aquí es que la modificación de .data no será rastreada por autograd, por lo que no informará de un error al hacer retroceder(), y obtendrá un valor de retroceso incorrecto .

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

#这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值
out.sum().backward()
print(a.grad)

devolver:

(deeplearning) userdeMBP:pytorch user$ python test.py 
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])

El siguiente resultado es correcto porque el resultado de sum() cambia y el valor intermedio a.sigmoid() no se ve afectado, por lo que no tiene efecto en el gradiente:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid().sum() #但是如果sum写在这里,而不是写在backward()前,得到的结果是正确的
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

#没有写在这里
out.backward()
print(a.grad)

 devolver:

(deeplearning) userdeMBP:pytorch user$ python test.py 
None
tensor(2.5644, grad_fn=<SumBackward0>)
tensor(2.5644)
tensor(0.)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])

3. .separar_()

separe a Variabledel gráfico que lo creó y configúrelo como una hojavariable

De hecho, es equivalente a que la relación entre variables es x -> m -> y, la variable hoja aquí es x, pero en este momento la operación .detach_() se realiza en m, que en realidad son dos operaciones:

  • Establezca el valor de m's grad_fn en Ninguno, de modo que m ya no se asocie con el nodo anterior x, y la relación aquí se convertirá en x, m -> y, y m se convertirá en un nodo hoja en este momento
  • Luego, require_grad de m se establecerá en False, de modo que no se buscará el gradiente de m cuando se realice back() en y

Mirándolo de esta manera, detach() es en realidad muy similar a detach_() La diferencia entre los dos es que detach_() es un cambio en sí mismo, y detach() genera una nueva variable.

Por ejemplo, si detach() se realiza en m en x -> m -> y, si luego se arrepiente, aún desea operar en el gráfico de cálculo original. Pero si se ejecuta detach_(), entonces el gráfico de cálculo original también ha cambiado y no puede retractarse de su palabra.

Supongo que te gusta

Origin blog.csdn.net/weixin_45684362/article/details/131987689
Recomendado
Clasificación