How PyTorch loads pre-trained model weights

When working on 2D object detection, we often modify parts of the neural network, for example by adding CBAM or changing the FPN. However, when the modified network is trained, loading the pre-trained weights may report no error, and yet performance after the modification may decrease rather than improve. This article summarizes the points behind this behavior.

In PyTorch, the code for loading pre-trained weights into a network is very simple:

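import torch

net = model()   # build the network (model is defined elsewhere)
net.to(device)  # move it to the target device, e.g. torch.device('cuda')
net.load_state_dict(torch.load('params.pth'))  # read the weights and copy them in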

Here:

  • torch.load('params.pth') only loads the model parameters; it does not put them into the network. The parameters are read into memory as key-value pairs (a dictionary).
  • load_state_dict is the method that copies the parameter dictionary into the network.
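
The two steps can be seen separately in a minimal sketch. TinyNet and the temporary params.pth file are hypothetical stand-ins for the article's model() and weight file; only torch.save, torch.load, state_dict and load_state_dict are real PyTorch calls.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's model()."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = self.conv(x).mean(dim=(2, 3))  # global average pooling
        return self.fc(x)


# Save the weights of one network to a temporary params.pth file.
source_net = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "params.pth")
torch.save(source_net.state_dict(), path)

# Step 1: torch.load only returns a dictionary of name -> tensor.
state_dict = torch.load(path, map_location="cpu")
loaded_keys = list(state_dict.keys())
print(loaded_keys)

# Step 2: load_state_dict copies those tensors into the network.
target_net = TinyNet()
target_net.load_state_dict(state_dict)
same_after = torch.equal(target_net.fc.weight, source_net.fc.weight)
print(same_after)
```

Until load_state_dict is called, target_net still holds its own random initialization; the dictionary returned by torch.load is only data in memory.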

After modifying the network, we only need to do the following to load the previous pre-trained weights into the new network. In mmdetection, or in open-source code from some experts, the pre-trained weights can still be loaded after the network is modified, for the following reasons:
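
A common pattern that makes this work, sketched here as an assumption and not as the exact mmdetection code, is to keep only the pre-trained entries whose name and shape both match the new network, and then call load_state_dict with strict=False. Baseline and Modified are hypothetical toy networks; the attention layer stands in for an added module such as CBAM.

```python
import torch
import torch.nn as nn


class Baseline(nn.Module):
    """Hypothetical original network that the weights were trained on."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3)
        self.head = nn.Linear(8, 2)


class Modified(nn.Module):
    """Hypothetical modified network."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3)  # unchanged layer
        self.attention = nn.Linear(8, 8)  # new layer (stand-in for CBAM)
        self.head = nn.Linear(8, 4)  # same name, different shape


pretrained = Baseline().state_dict()  # plays the role of torch.load('params.pth')
net = Modified()
model_dict = net.state_dict()

# Keep only entries whose name and shape both match the new network.
compatible = {
    name: tensor
    for name, tensor in pretrained.items()
    if name in model_dict and tensor.shape == model_dict[name].shape
}
skipped = sorted(set(pretrained) - set(compatible))

# strict=False tolerates layers that receive no pre-trained weights.
result = net.load_state_dict(compatible, strict=False)

print("loaded: ", sorted(compatible))
print("skipped:", skipped)
print("missing:", result.missing_keys)
```

No error is raised, but every layer listed under missing keeps its random initialization. Printing these lists after loading is a cheap way to check how much of the pre-trained model was actually reused, which may help explain a drop in performance after modifying the network.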

Origin blog.csdn.net/qq_42308217/article/details/123140481