[aprendizaje de pytorch] uso y puntos de atención de cuatro métodos de enlace (register_forward_hook, etc.)

  Para ahorrar memoria de video (memoria), Pytorch no guarda variables intermedias durante el proceso de cálculo, incluidos mapas de características de capas intermedias y gradientes de tensores que no son hojas. En ocasiones es necesario visualizar o modificar estas variables intermedias al analizar la red, en este momento es necesario registrar un gancho (hook) para exportar las variables intermedias requeridas. Hay muchas introducciones a esto en Internet, pero las leí y algunas de ellas son inexactas o difíciles de entender. Las resumiré aquí y daré el uso real y los puntos de atención.
Hay cuatro métodos de enlace:
torch.Tensor.register_hook()
torch.nn.Module.register_forward_hook()
torch.nn.Module.register_backward_hook()
torch.nn.Module.register_forward_pre_hook().

1, antorcha.Tensor.register_hook(gancho)

  Se utiliza para exportar el gradiente del tensor especificado o modificar el valor del gradiente.

import torch
def grad_hook(grad):
    grad *= 2
x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
h = x.register_hook(grad_hook)
z.backward()
print(x.grad)
h.remove()    # removes the hook
>>> tensor([2., 2., 2., 2.])

Nota: (1) El código anterior es válido, pero no será válido si se escribe como grad = grad * 2, porque no hay ninguna operación local en grad en este momento y el nuevo valor de grad no se pasa al especificado degradado. Para estar seguro, es mejor escribir return grad en la declaración def. Ahora mismo:

def grad_hook(grad):
    grad = grad * 2
    return grad

(2) El gancho se puede cancelar con el método remove(). Tenga en cuenta que eliminar () debe estar después de atrás (), porque pytorch solo comienza a calcular el gradiente cuando se ejecuta la instrucción hacia atrás (), y solo "registra" un gancho de graduación cuando x.register_hook (grad_hook), en este momento. no se calcula, y la ejecución de eliminar cancelará el gancho, y luego el gancho no funcionará cuando esté hacia atrás ().
(3) Si la función de enlace está definida en la clase, el parámetro de entrada primero debe agregarse a sí mismo, es decir

def grad_hook(self, grad):
    ...

2, torch.nn.Module.register_forward_hook (módulo, entrada, salida)

  Se utiliza para exportar los tensores de entrada y salida del submódulo especificado (que puede ser una capa, módulo, etc. nn.Tipo de módulo), pero solo se puede modificar la salida. A menudo se usa para exportar o modificar características de convolución. mapa.

inps, outs = [],[]
def layer_hook(module, inp, out):
    inps.append(inp[0].data.cpu().numpy())
    outs.append(out.data.cpu().numpy())

hook = net.layer1.register_forward_hook(layer_hook)
output = net(input)
hook.remove()

Nota: (1) Debido a que el módulo puede tener múltiples entradas, la entrada es de tipo tupla y el tensor que contiene debe extraerse antes de la operación; la salida es de tipo tensor y se puede usar directamente.
   (2) No lo coloque en la memoria de video después de exportarlo, a menos que tenga A100.
   (3) Solo se puede modificar el valor de salida, y el valor de entrada no se puede modificar (no se puede devolver y la modificación local tampoco es válida). Al modificar, es mejor devolver en forma de retorno, tales como:

def layer_hook(self, module, inp, out):
    out = self.lam * out + (1 - self.lam) * out[self.indices]
    return out

  Este código se utiliza en la combinación de múltiples para combinar las características de la capa intermedia para lograr la mejora de los datos, donde self.lam es un valor de probabilidad [0,1] y self.indices es el número de serie después de la mezcla.

3, torch.nn.Module.register_forward_pre_hook (módulo, en)

  Se utiliza para exportar o modificar los tensores de entrada del submódulo especificado.

def pre_hook(module, inp):
    inp0 = inp[0]
    inp0 = inp0 * 2
    inp = tuple([inp0])
    return inp

hook = net.layer1.register_forward_pre_hook(pre_hook)
output = net(input)
hook.remove()

Nota: (1) El valor inp es un tipo de tupla, por lo que primero debe extraer el tensor que contiene, luego realizar otras operaciones y luego convertirlo en una tupla y devolverlo.
(2) Esta oración se llamará solo cuando se ejecute salida = net (entrada), y remove() se puede usar para cancelar el enlace después de llamar.

4, torch.nn.Module.register_backward_hook (módulo, grad_in, grad_out)

  Se utiliza para exportar el gradiente de los tensores de entrada y salida del submódulo especificado, pero solo se puede modificar el gradiente del tensor de entrada (es decir, solo se puede devolver gin) y el gradiente del tensor de salida no se puede modificar. .

gouts = []
def backward_hook(module, gin, gout):
    print(len(gin),len(gout))
    gouts.append(gout[0].data.cpu().numpy())
    gin0,gin1,gin2 = gin
    gin1 = gin1*2
    gin2 = gin2*3
    gin = tuple([gin0,gin1,gin2])
    return gin

hook = net.layer1.register_backward_hook(backward_hook)
loss.backward()
hook.remove()

Nota:
(1) grad_in y grad_out son tuplas, que deben desenvolverse primero y luego volver a colocarse en la tupla para regresar después de realizar operaciones al modificar.
(2) Esta función de enlace se llama en la declaración hacia atrás (), por lo que remove() debe colocarse después de hacia atrás () para cancelar el enlace.

Supongo que te gusta

Origin blog.csdn.net/Brikie/article/details/114255743
Recomendado
Clasificación