Use register_forward_hook() to precisely capture the input and output of a specific layer in a model

While reading a paper I came across methods that take a hidden layer in the middle of the model and use it as a classifier, comparing its performance against using the model's last layer as the classifier. That got me wondering how to quickly and conveniently pull out the output of a particular layer, and it turns out PyTorch already provides a hook mechanism that does exactly this.

hook

A hook is, as the name suggests, used to "hook out" the input, output, or other information of a particular layer in a network. If you want to inspect the details of some layer, there is no need to scatter print statements through the network definition; just write a hook function and register it.
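As a minimal sketch of the idea (using a standalone ReLU layer as a stand-in for "some layer in a network"): a hook function receives the module, its input (as a tuple), and its output, and here simply stores the output in a list.

```python
import torch
import torch.nn as nn

captured = []  # container for the hooked-out outputs

def hook(module, input, output):
    # input is a tuple of the layer's inputs; output is the layer's output tensor
    captured.append(output)

layer = nn.ReLU()
layer.register_forward_hook(hook)
y = layer(torch.tensor([-1.0, 2.0]))
print(captured[0])  # tensor([0., 2.]) -- identical to y
```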

register_forward_hook

As the source code explains, a hook only takes effect if it is registered before forward() runs; registering it after forward() has already been called does nothing for that pass. In other words, you should write the hook function first, instantiate the network, register the hook on the layers you want, and only then run data through the network.

def register_forward_hook(self, hook):
    r"""Registers a forward hook on the module.

    The hook will be called every time after :func:`forward` has computed an output.
    It should have the following signature::

        hook(module, input, output) -> None or modified output

    The hook can modify the output. It can modify the input inplace but
    it will not have effect on forward since this is called after
    :func:`forward` is called.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(self._forward_hooks)
    self._forward_hooks[handle.id] = hook
    return handle
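As the docstring notes, registration returns a `RemovableHandle`. A small sketch of how that handle is used to detach a hook once it is no longer needed:

```python
import torch
import torch.nn as nn

calls = []  # count how many times the hook fires
layer = nn.Linear(2, 2)
handle = layer.register_forward_hook(lambda m, i, o: calls.append(1))

x = torch.zeros(1, 2)
layer(x)         # hook fires
handle.remove()  # detach the hook via the returned handle
layer(x)         # hook no longer fires
print(len(calls))  # 1
```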

problem

A model often contains multiple Linear layers, and the modules yielded by net.children() only expose their type, so layers of the same type look identical. The trick is to check the current Linear layer's input and output dimensions (in_features and out_features) and use them to lock onto exactly the layer we want; the same idea works for other module types as well.

code

import torch
import torch.nn as nn
class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        # add_module is redundant here: assigning a Module attribute already registers it
        self.add_module('linear1', self.linear_1)
        self.add_module('linear2', self.linear_2)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out

    def initialize(self):
        """Custom initialization, used to verify that the hook sees these weights in effect."""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True
# Define the hook function, which decides what to do with the hooked-out information.
# The lists below act as containers for each layer's input/output tensors,
# and module_name records which module each capture came from.
module_name = []
features_in_hook = []
features_out_hook = []
# The hook function takes exactly 3 arguments; they are passed in by the framework
# and their meaning is fixed. This hook appends the captured input and output to
# the feature lists and records the module's class.
def hook(module,input,output):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(input)
    features_out_hook.append(output)
    return None
# Register the hook on the layers we need.
# Registration must happen before forward() is called, i.e. after instantiating
# the network but before running data through it. The code below registers the
# hook only on the Linear layer whose dimensions match; the commented-out
# condition would instead register it on every Linear layer.
net=TestForHook()
net_children=net.children()
# Different Linear layers usually differ in in_features/out_features,
# so those attributes can identify the exact layer we want.
for child in net_children:
    if isinstance(child, nn.Linear) and child.in_features == 2 and child.out_features == 2:
    # if isinstance(child, nn.Linear):
        child.register_forward_hook(hook=hook)
# Compare against the input/output features that forward() itself returns
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])
out, features_in_forward, features_out_forward = net(x)
# print("*"*5+"forward return features"+"*"*5)
# print(features_in_forward)
# print(features_out_forward)
# print("*"*5+"forward return features"+"*"*5)
# The hook records into plain Python lists, so they can be printed directly
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)
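As an alternative to matching on dimensions, net.named_modules() yields each submodule together with its registered name, so a specific layer can also be selected by name. A small sketch of this idea (using a hypothetical nn.Sequential model, not the TestForHook class above, and a closure so each hook remembers which layer it belongs to):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))

outs = {}  # layer name -> captured output

def make_hook(name):
    # closure: bake the layer's name into its hook
    def hook(module, input, output):
        outs[name] = output
    return hook

# named_modules() yields (name, module) pairs, including nested submodules
for name, module in net.named_modules():
    if isinstance(module, nn.Linear):
        module.register_forward_hook(make_hook(name))

x = torch.ones(1, 2)
net(x)
print(sorted(outs.keys()))  # ['0', '2']
```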

Origin blog.csdn.net/gary101818/article/details/132453662