Say goodbye to copy-and-paste code: foolproof extraction of intermediate-layer features in PyTorch

Summary: Feature extraction is a common technique in image processing, and the quality of the extracted features strongly affects how well a model generalizes.

Feature extraction is widely used in machine learning, pattern recognition and image processing.

It starts from an initial set of measured data and builds derived values (features) that are informative and non-redundant, which makes the subsequent learning and generalization steps easier.

When training models with PyTorch, you often need to extract the features produced by an intermediate layer of the model. There are three ways to do this.

Three methods for extracting intermediate-layer features

1. Store the feature as an attribute of the model class

Method: modify the forward function and add one line that assigns the feature to an attribute on self, for example self.feature_map = feature. After a forward pass, read that attribute from the model instance.

Remarks: Suitable when you only need the values of the intermediate features and do not need their gradients.

Code example:

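import torch.nn as nn
import torch.nn.functional as F


# Define a Convolutional Neural Network
class Net(nn.Module):
    def __init__(self, kernel_size=5, n_filters=16, n_layers=3):
        super().__init__()
        ...  # layer definitions (self.head, self.body, self.fc) are elided in the original

    def forward(self, x):
        x = self.body(self.head(x))
        self.featuremap1 = x.detach()  # core line: cache the feature, detached from the autograd graph
        return F.relu(self.fc(x))


model_ft = Net()
train_model(model_ft)  # training loop defined elsewhere (not shown in the original)
feature_output1 = model_ft.featuremap1.transpose(1, 0).cpu()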
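Because the snippet above elides the layer definitions and the training loop, it cannot be run as is. The following is a minimal self-contained sketch of the same idea. The class name TinyNet, the layer sizes, and the input shape are all hypothetical and chosen only so that the example runs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's Net; layer sizes are illustrative."""

    def __init__(self, n_filters=16, n_classes=10):
        super().__init__()
        self.head = nn.Conv2d(3, n_filters, kernel_size=5, padding=2)
        self.body = nn.Sequential(nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc = nn.Linear(n_filters, n_classes)
        self.featuremap1 = None  # overwritten on every forward pass

    def forward(self, x):
        x = self.body(self.head(x))
        # detach() keeps the values but drops the link to the autograd graph
        self.featuremap1 = x.detach()
        return F.relu(self.fc(x))


model = TinyNet()
out = model(torch.rand(4, 3, 32, 32))  # random batch of 4 RGB images
feature = model.featuremap1.cpu()
print(out.shape, feature.shape)
```

Note that the attribute always holds the feature from the most recent forward pass, so it should be read before the model is called again.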

2. Use the hook mechanism

A hook is a callable that PyTorch invokes at a specific point during the forward or backward pass, so you can add behavior without modifying the model's main code. There are three types of hooks in PyTorch:

torch.autograd.Variable.register_hook

torch.nn.Module.register_backward_hook

torch.nn.Module.register_forward_hook

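The first is registered on a tensor (Variable has since been merged into Tensor, so in current PyTorch this is torch.Tensor.register_hook), and the latter two are registered on an nn.Module. In recent PyTorch releases, register_backward_hook is deprecated in favor of register_full_backward_hook.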
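To make the tensor-level hook concrete, here is a small sketch that captures the gradient of an intermediate tensor during the backward pass. It uses torch.Tensor.register_hook, the current name of the Variable hook listed above. The dictionary grads and the helper save_grad are hypothetical names introduced for illustration.

```python
import torch

grads = {}


def save_grad(name):
    # Build a hook that stores a copy of the gradient under the given name
    def hook(grad):
        grads[name] = grad.clone()
    return hook


x = torch.ones(3, requires_grad=True)
h = x * 2  # intermediate (non-leaf) tensor; its .grad is not kept by default
h.register_hook(save_grad("h"))

loss = (h ** 2).sum()
loss.backward()

print(grads["h"])  # d(loss)/dh = 2 * h = tensor([4., 4., 4.])
```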

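Method: Register a forward hook on the target module with register_forward_hook before calling the model. The hook receives the module's output on each forward pass, which gives you the feature. To obtain gradients instead, use a backward hook.

Remarks: More powerful but also more complex. It requires some understanding of PyTorch internals.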
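A minimal sketch of the forward-hook approach might look like the following. The toy nn.Sequential model, the features dictionary, and the make_hook helper are assumptions made for illustration. Only register_forward_hook and the returned handle's remove method are PyTorch API.

```python
import torch
import torch.nn as nn

# Toy model used only for illustration
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

features = {}


def make_hook(name):
    # A forward hook is called as hook(module, inputs, output)
    def hook(module, inputs, output):
        features[name] = output.detach()
    return hook


# Attach the hook to the ReLU layer; keep the handle so it can be removed later
handle = model[1].register_forward_hook(make_hook("relu"))

with torch.no_grad():
    out = model(torch.rand(2, 3, 32, 32))

handle.remove()  # detach the hook once the feature has been captured
print(features["relu"].shape)  # torch.Size([2, 8, 32, 32])
```

Removing the handle matters in longer scripts: a hook that stays registered keeps firing on every forward pass and keeps the captured tensors alive.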

3. Use torchextractor

torchextractor is a standalone Python package whose Extractor behaves like an nn.Module. You only need to provide the names of the modules whose outputs you want, and it extracts the intermediate-layer features for you.

Compared with writing forward hooks by hand, torchextractor acts as a wrapper around the model. It also makes fewer assumptions about the model's structure than torchvision's IntermediateLayerGetter.

In terms of functionality, torchextractor's main advantages are its support for nested modules, custom capture (caching) operations, and compatibility with ONNX.

torchextractor greatly simplifies feature extraction in PyTorch: there is no need to copy and paste large amounts of code or to rewrite the forward function. This makes it friendlier to beginners and easier to use.
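To illustrate what "a wrapper around forward hooks" can mean, the sketch below builds a tiny extractor that looks up modules by dotted name, which is how nested modules can be addressed. This is a hypothetical illustration, not torchextractor's actual implementation. The class SimpleExtractor and the toy backbone are assumptions.

```python
import torch
import torch.nn as nn


class SimpleExtractor(nn.Module):
    """Hypothetical wrapper for illustration; NOT torchextractor's real code."""

    def __init__(self, model, module_names):
        super().__init__()
        self.model = model
        self.features = {}
        # named_modules() yields dotted names such as "0.0" for nested modules
        modules = dict(model.named_modules())
        for name in module_names:
            modules[name].register_forward_hook(self._make_hook(name))

    def _make_hook(self, name):
        def hook(module, inputs, output):
            self.features[name] = output
        return hook

    def forward(self, *args, **kwargs):
        self.features = {}  # start a fresh dict for this call
        output = self.model(*args, **kwargs)
        return output, self.features


backbone = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU()),
    nn.Conv2d(4, 8, 3, padding=1),
)
extractor = SimpleExtractor(backbone, ["0.0", "1"])
output, features = extractor(torch.rand(2, 3, 16, 16))
print({name: tuple(f.shape) for name, f in features.items()})
```

The wrapped model's forward function is left untouched, which is the property the text above attributes to torchextractor.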

torchextractor in practice

Installation

pip install torchextractor  # stable
pip install git+https://github.com/antoinebrl/torchextractor.git  # latest

Requirements

Python 3.6 and above

PyTorch 1.4.0 and above

Usage

import torch
import torchvision
import torchextractor as tx


model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)


# The batch dimension matches dummy_input, which has 7 samples:
# {
#   'layer1': torch.Size([7, 64, 56, 56]),
#   'layer2': torch.Size([7, 128, 28, 28]),
#   'layer3': torch.Size([7, 256, 14, 14]),
#   'layer4': torch.Size([7, 512, 7, 7]),
# }

For the complete documentation, see:

https://github.com/antoinebrl/torchextractor

These are the three methods for extracting intermediate-layer features covered in this article. If you have a better solution, or other PyTorch-related questions you would like to see discussed, please leave a comment or send a private message.

References:

https://www.reddit.com/r/MachineLearning/comments/m2vwf9/p_pytorch_intermediate_feature_extraction/

https://www.zhihu.com/question/68384370

Origin blog.csdn.net/HyperAI/article/details/114874876