Four hook functions in PyTorch

  When training a neural network, you sometimes need to observe the inputs and outputs of the model's internal modules, or to adjust an intermediate module's output without modifying the original module structure. PyTorch supports this through hook callback functions, registered mainly with four functions: register_forward_hook, register_forward_pre_hook, register_full_backward_hook, and register_full_backward_pre_hook. These can be called on any module that inherits from nn.Module; you pass in a hook function and register it, and the hook is then called at the corresponding stage of the module's execution to implement the desired behavior.

register_forward_hook(self, hook, *, prepend=False, with_kwargs=False)

  Registers a callback on the module that is executed after the module's forward pass.

  hook(module, args, output): the callback to execute. module is a reference to the current module, args is a tuple of the module's positional forward inputs, and output is the module's forward output. The hook may return a modified output to replace the module's forward output.

  prepend: if True, the hook is placed at the front of this module's list of forward hooks so that it runs before existing hooks; otherwise it is appended to the end.

  with_kwargs: whether the keyword arguments passed to forward are also passed to the hook. If True, the hook signature becomes hook(module, args, kwargs, output), where kwargs is a dict of the keyword arguments.

  register_forward_hook itself returns a handle (a torch.utils.hooks.RemovableHandle); calling handle.remove() removes the registered hook.
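  As a minimal sketch of the above (the model, layer sizes, and tensors below are made up for illustration, not from the original post), a forward hook can capture a layer's output and also replace it by returning a new value:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured = {}

def forward_hook(module, args, output):
    # args is the tuple of positional inputs, output is the forward result.
    captured["output"] = output.detach()
    # Returning a tensor replaces this module's forward output downstream.
    return output * 2.0

handle = model[0].register_forward_hook(forward_hook)
y = model(torch.randn(3, 4))
print(captured["output"].shape)  # torch.Size([3, 8])
handle.remove()  # the hook no longer fires after removal
```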

register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False)

  Registers a callback on the module that is executed before the module's forward pass.

  hook(module, args): args is the module's forward input. The hook may return modified args to replace the module's forward input.

  The other parameters and behavior are the same as for register_forward_hook; a small sketch follows below.
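  As a rough illustration (the layer and the normalization step are arbitrary choices, not from the original post), a forward pre-hook can rewrite a module's input by returning a new tuple of arguments:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def forward_pre_hook(module, args):
    x = args[0]
    # Returning a tuple replaces the positional inputs passed to forward.
    return (x / (x.norm(dim=-1, keepdim=True) + 1e-8),)

handle = layer.register_forward_pre_hook(forward_pre_hook)
y = layer(torch.randn(3, 4))
handle.remove()
```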

register_full_backward_hook(self, hook, prepend=False)

  Registers a callback on the module that is executed after the module's backward pass, once the gradients with respect to the module's inputs have been computed.

  hook(module, grad_input, grad_output): grad_input and grad_output are the gradients with respect to the module's forward inputs and outputs, respectively. The hook may return a modified grad_input to replace the gradient with respect to the module's inputs.
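  A minimal sketch of inspecting gradients with a full backward hook (the layer, input shapes, and loss are illustrative only):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def backward_hook(module, grad_input, grad_output):
    # grad_input: gradients w.r.t. the module's forward inputs;
    # grad_output: gradients w.r.t. the module's forward outputs.
    print([g.shape if g is not None else None for g in grad_input])
    print([g.shape for g in grad_output])
    # Returning a tuple shaped like grad_input would replace it;
    # returning nothing leaves the gradients unchanged.

handle = layer.register_full_backward_hook(backward_hook)
x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
handle.remove()
```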

register_full_backward_pre_hook(self, hook, prepend=False)

  Registers a callback on the module that is executed before the module's backward pass.

  hook(module, grad_output): grad_output is the gradient with respect to the module's forward output. The hook may return a modified grad_output to replace it.
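  As a sketch (the layer and the clamping threshold are arbitrary, chosen only for illustration), a full backward pre-hook can rewrite the output gradient before the module's own backward computation runs:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def backward_pre_hook(module, grad_output):
    # Returning a tuple with the same structure as grad_output replaces it
    # before gradients w.r.t. this module's inputs are computed.
    return tuple(g.clamp(-1.0, 1.0) for g in grad_output)

handle = layer.register_full_backward_pre_hook(backward_pre_hook)
x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
handle.remove()
```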

Origin blog.csdn.net/qq_37189298/article/details/133698900