[PyTorch study notes] 12. Modifying the weights of a pre-trained model (to use a pre-trained model on single-channel grayscale images)


When training on single-channel images, i.e. grayscale images (such as medical imaging data), we often start from a pre-trained model.
However, most pre-trained models are trained on the ImageNet dataset, whose images are three-channel color images.
So the model has to be modified: the first convolutional layer must be changed from a 3-channel convolution to a 1-channel convolution.
(For example, the figure below shows how the first convolutional layer changes when its input goes from three channels to one.)

[Figure: the first convolutional layer before and after changing its input from three channels to one]

A grayscale image is a weighted average of the three channels of a color image, so we can collapse each filter's three per-channel kernels into one by adding (summing) them element-wise, and then convolve the grayscale image with the resulting single-channel kernel. The justification is that convolution is linear: convolving a grayscale image g with the summed kernel gives exactly the same output as convolving the original 3-channel kernel with an image whose R, G, and B channels all equal g. The pretrained feature-extraction ability is therefore largely preserved.
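The linearity argument can be checked numerically. This sketch uses random weights as a stand-in for a pretrained first layer (the names `w_rgb`, `w_gray`, `x_gray` and the 32x32 input size are illustrative assumptions), and compares the summed 1-channel kernel on a grayscale batch against the original 3-channel kernel on the same batch replicated into three channels. Note that the equality is exact only when the three input channels are identical; per-channel ImageNet mean/std normalization would break that.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for a pretrained first layer: 64 filters over 3 input channels (random, not real weights).
w_rgb = torch.randn(64, 3, 7, 7, dtype=torch.float64)

# Collapse the RGB kernels into one kernel per filter by summing over the input-channel dim.
w_gray = w_rgb.sum(dim=1, keepdim=True)  # shape (64, 1, 7, 7)

# A batch of two single-channel images.
x_gray = torch.randn(2, 1, 32, 32, dtype=torch.float64)

# The same images copied into three identical channels (R = G = B = gray).
x_rgb = x_gray.repeat(1, 3, 1, 1)  # shape (2, 3, 32, 32)

# Same stride/padding as ResNet-50's conv1.
out_rgb = F.conv2d(x_rgb, w_rgb, stride=2, padding=3)
out_gray = F.conv2d(x_gray, w_gray, stride=2, padding=3)

print(out_gray.shape)                     # torch.Size([2, 64, 16, 16])
print(torch.allclose(out_rgb, out_gray))  # True
```

Float64 is used only so the comparison is not affected by rounding; in float32 the two outputs agree up to small floating-point error.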

Let's take the ResNet50 pre-trained model as an example and modify the parameters of its first convolutional layer so that it can be trained on single-channel images.

1. Export the model parameters and modify them

To modify a model's parameters in PyTorch when the network structure also changes, you need to change the network structure first and then assign the parameter values.
That is: export the pre-trained parameters → modify the pre-trained parameters → modify the model's network structure → load the modified parameters back into the model.

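```python
from torchvision.models import resnet50, ResNet50_Weights

# torchvision >= 0.13 API; older versions use resnet50(pretrained=True), which loads the same IMAGENET1K_V1 weights
net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
print(net.conv1)  # inspect the structure of the first convolutional layer

weights = net.state_dict()  # state_dict() lists the parameters as an ordered dict
print(weights.keys())  # inspect the parameter keys
print(weights['conv1.weight'].shape)  # look up a parameter by key and check its shape
```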


Modify the model parameters

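```python
# Change the first conv layer's parameters from a 3-channel to a 1-channel convolution
# by summing the kernels over the input-channel dimension (dim 1)
weights['conv1.weight'] = weights['conv1.weight'].sum(1, keepdim=True)
print(weights['conv1.weight'].shape)  # check the new shape
```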
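The `keepdim=True` argument matters here. An `nn.Conv2d` weight is 4-D, `(out_channels, in_channels, kH, kW)`; without `keepdim` the summed dimension disappears and the tensor becomes 3-D, so `load_state_dict` would later fail with a size-mismatch error. A small sketch on a random tensor with the same shape as ResNet-50's `conv1.weight` (the variable names are illustrative):

```python
import torch

# Random tensor shaped like ResNet-50's conv1.weight: (out_channels, in_channels, kH, kW)
w = torch.randn(64, 3, 7, 7)

collapsed = w.sum(1)               # the input-channel dim is dropped
kept = w.sum(1, keepdim=True)      # the input-channel dim is kept with size 1

print(collapsed.shape)  # torch.Size([64, 7, 7])
print(kept.shape)       # torch.Size([64, 1, 7, 7])
```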


2. Modify the model structure and load the parameters back

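```python
import torch.nn as nn

# Replace the first conv layer with one that takes a single input channel
net.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Load the modified parameters back into the model
net.load_state_dict(weights)
print(net)
```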
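Step 2 hard-codes the new layer's hyperparameters. As an alternative, the replacement layer can be built from the existing layer's own attributes and the summed weights copied in directly, which skips the `state_dict` round trip. A minimal sketch; `make_single_channel_conv` is a hypothetical helper name, not a PyTorch or torchvision API, and it assumes an ungrouped `nn.Conv2d`:

```python
import torch
import torch.nn as nn


def make_single_channel_conv(conv: nn.Conv2d) -> nn.Conv2d:
    """Hypothetical helper: build a 1-input-channel copy of `conv` whose
    kernels are the original kernels summed over the input channels."""
    # Summing over input channels only makes sense for an ordinary (ungrouped) conv.
    assert conv.groups == 1, "grouped convolutions are not handled here"
    new_conv = nn.Conv2d(
        in_channels=1,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    # Copy the collapsed weights (and bias, if any) without recording the op in autograd.
    with torch.no_grad():
        new_conv.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Stand-in with the same configuration as ResNet-50's conv1 (random init, no download).
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
gray_conv1 = make_single_channel_conv(conv1)
print(gray_conv1.weight.shape)  # torch.Size([64, 1, 7, 7])
```

With the real model this would be used as `net.conv1 = make_single_channel_conv(net.conv1)`. Reading the attributes from the old layer avoids copying stride, padding, and bias settings by hand.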


Successfully modified!
Next, you can fine-tune this pre-trained model with your own grayscale dataset.


Original post: blog.csdn.net/takedachia/article/details/130302513