Table of contents
1.2 What problems can transfer learning solve?
1.3 Three problems faced by transfer learning
1.4 Classification of Transfer Learning
1.4.1 Classification by learning style
1.4.2 Classification by method of use
3 Transfer learning code implementation
3.2 Pre-training model download
3.3 Transfer learning with a pre-trained ResNet model in PyTorch
3.4 Training the network without transfer learning in PyTorch
Introduction
Based on a pre-trained ResNet model, this project builds a fruit classification model through transfer learning. Over 30 epochs of training the model converges quickly and reaches an accuracy above 96%. Working through the project shows how to apply transfer learning to a pre-trained model in order to build a new deep learning model. Readers can use it as a reference to add more classes, or to build a new classifier on a different dataset through transfer learning.
1 Transfer learning
1.1 What is transfer learning
Transfer learning is a machine learning method in which a model developed for task A is used as the starting point for developing a model for task B. In other words, it improves learning on a new task by transferring knowledge from related tasks that have already been learned. Although most machine learning algorithms are designed to solve a single task, developing algorithms that facilitate transfer learning is a topic of ongoing interest in the machine learning community.
Put simply, transfer learning is the ability to draw inferences about new cases from a known one. It uses existing knowledge to learn new knowledge, and its core is finding the similarity between the two; transferring along that similarity is what achieves the purpose of transfer learning. Many things share common ground, so the central question of transfer learning is how to find those similarities and use them as a bridge for learning new knowledge.
1.2 What problems can transfer learning solve?
Most current artificial intelligence techniques depend on large amounts of high-quality data. Data constructed in the laboratory can meet basic training needs to some extent, but a model trained on it is often not accurate enough in the field because the constructed data differs from the real data. This places a new requirement on AI algorithms: make full use of laboratory data while still performing well on real field data. Transfer learning addresses this by letting a model that would otherwise need massive amounts of training data be trained effectively on a limited dataset, which saves considerable effort.
Transfer learning is not limited to a specific field: any problem that fits the transfer learning setting can use it. Computer vision, text classification, behavior recognition, natural language processing, indoor positioning, video surveillance, public opinion analysis, human-computer interaction and other fields can all apply transfer learning.
1.3 Three problems faced by transfer learning
1.3.1 When to transfer
"When to transfer" concerns whether transfer learning is feasible and why it should be used. This question belongs at the very first step: given the target task to be learned, first judge whether the task is suitable for transfer learning.
1.3.2 What and where to transfer
Once the task is judged suitable for transfer learning, the second step is to decide what to transfer and where to transfer it from. "What" refers to the knowledge being transferred, such as neural network weights or certain parameters of a feature transformation matrix. "Where" refers to the place it is transferred from, such as a particular source domain, a particular neuron, or a tree in a random forest.
1.3.3 How to transfer
This step is the focus of most transfer learning methods: given the source domain and the target domain, find the transfer method that achieves the best performance.
1.4 Classification of Transfer Learning
1.4.1 Classification by learning style
- Instance-based transfer learning: different samples are given different weights. The more similar a sample is to the target-domain samples, the higher its weight, and high-weight samples are transferred with higher priority.
- Feature-based transfer learning: the features are transformed. If the features of the source domain and the target domain are not in the same space, or are not similar in the original space, they are mapped into a common space in which the distance between corresponding points of the two domains is minimized, which completes the transfer.
- Model-based transfer learning: the parameters of a model are reused. This approach is especially common with neural networks, because the structure of a neural network can be transferred directly. Fine-tuning is a typical example of transferring model parameters.
- Relation-based transfer learning: the relationships within similar scenes are exploited. If the relationship between entities in the source domain and in the target domain can be identified, it can be transferred from one to the other. For example, a teacher and students in a class can be compared to the participants in a company meeting. This analogy is a kind of relation transfer.
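As a rough illustration of the instance-based idea in the list above, the sketch below weights each source sample by how close it lies to the mean of the target samples. The function name `instance_weights`, the Gaussian similarity and the `bandwidth` parameter are assumptions chosen for this example; practical instance-weighting methods are usually more elaborate.

```python
import math


def instance_weights(source, target, bandwidth=1.0):
    """Weight each source sample by its similarity to the target-domain mean.

    Hypothetical helper: similarity is a Gaussian of the squared distance,
    and the weights are normalized so that they sum to 1.
    """
    dim = len(target[0])
    target_mean = [sum(x[d] for x in target) / len(target) for d in range(dim)]
    raw = []
    for x in source:
        sq_dist = sum((x[d] - target_mean[d]) ** 2 for d in range(dim))
        raw.append(math.exp(-sq_dist / (2 * bandwidth ** 2)))
    total = sum(raw)
    return [w / total for w in raw]


# Toy data: the second source sample sits right at the target mean.
source = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
target = [[0.9, 1.1], [1.1, 0.9]]
weights = instance_weights(source, target)
print(weights)
```

The source sample nearest to the target data receives the largest weight, and the distant one contributes almost nothing.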
1.4.2 Classification by method of use
- Fine-tuning: take a network trained by someone else and adapt it to your own task, initializing it with the pre-trained weights instead of random weights.
- Fixed feature extractor: use the pre-trained network as the feature extractor for the new task, that is, the first layers of the new model are taken from the pre-trained network.
2 ResNet network
2.1 Introduction to ResNet
Deep learning has made major breakthroughs in areas such as image classification, object detection, and speech recognition. However, as the number of network layers increases, the problems of vanishing and exploding gradients become prominent. With more layers, the gradient can shrink step by step during backpropagation, making it difficult for the network to converge. Conversely, exploding gradients make the parameter updates too large, so the network also fails to converge normally.
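A small piece of arithmetic shows how quickly this happens. The sketch below assumes, as a deliberate simplification, that each layer scales the gradient by one constant factor; `gradient_scale` is a hypothetical helper. The factor 0.25 is the maximum derivative of the sigmoid function, and 1.5 is an arbitrary value above 1 chosen to illustrate the exploding case.

```python
def gradient_scale(per_layer_factor, num_layers):
    """Overall gradient scaling after backpropagating through num_layers layers,
    assuming every layer multiplies the gradient by the same factor."""
    return per_layer_factor ** num_layers


for depth in (5, 20, 50):
    vanishing = gradient_scale(0.25, depth)  # factor below 1: gradient shrinks
    exploding = gradient_scale(1.5, depth)   # factor above 1: gradient grows
    print(f"depth={depth:2d}  vanishing={vanishing:.3e}  exploding={exploding:.3e}")
```

Real networks do not scale gradients by a single constant, but the exponential dependence on depth is the reason very deep plain networks are hard to train.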
In order to solve these problems, ResNet proposes an innovative idea: introducing a residual block (Residual Block). The design of the residual block allows the network to learn the residual mapping, which alleviates the vanishing gradient problem and makes the network easier to train.
The figure below shows a basic residual block. A skip connection carries the input of a layer forward, past one or more layers, and adds it to the output of the later layer before that layer's activation function is applied.
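The same idea can be written as a few lines of PyTorch. This is a minimal sketch of a residual block with equal input and output channels, not the exact torchvision implementation; the class name `BasicResidualBlock` and the layer sizes are chosen for illustration.

```python
import torch
from torch import nn


class BasicResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection around them."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x                      # the skip connection keeps the input
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))   # residual mapping F(x)
        return self.relu(out + identity)  # add the input, then activate


block = BasicResidualBlock(channels=16)
x = torch.randn(2, 16, 8, 8)
y = block(x)
print(y.shape)
```

Because the block only has to learn the difference between its input and the desired output, the gradient also has a direct path back through the addition.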
2.2 ResNet network structure
The classic ResNet architectures are ResNet-18, ResNet-34, ResNet-50, ResNet-101 and ResNet-152. ResNet-18 and ResNet-34 share the same basic structure and are relatively shallow networks. The other three are deeper networks, and ResNet-50 is the most commonly used among them.
3 Transfer learning code implementation
3.1 Dataset Introduction
Under the dataset directory there are 10 folders, each named after a fruit type and containing hundreds to thousands of images of that fruit, as shown in the following figure:
Taking the apple folder as an example, the content is as follows:
Download address: the first package, the second package
After both packages are downloaded, extract them to the /opt/dataset/fruit directory. The result looks like this:
# ll fruit/
total 508
drwxr-xr-x 2 root root 36864 Aug  2 16:35 apple
drwxr-xr-x 2 root root 24576 Aug  2 16:36 apricot
drwxr-xr-x 2 root root 40960 Aug  2 16:36 banana
drwxr-xr-x 2 root root 20480 Aug  2 16:36 blueberry
drwxr-xr-x 2 root root 45056 Aug  2 16:37 cherry
drwxr-xr-x 2 root root 12288 Aug  2 16:37 citrus
drwxr-xr-x 2 root root 49152 Aug  2 16:38 grape
drwxr-xr-x 2 root root 16384 Aug  2 16:38 lemon
drwxr-xr-x 2 root root 36864 Aug  2 16:39 litchi
drwxr-xr-x 2 root root 49152 Aug  2 16:39 mango
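Before training, it can be useful to confirm that the extracted directory really has this one-folder-per-class layout and to see how many images each class contains. The sketch below uses only the standard library; `count_images_per_class` is a hypothetical helper, and the list of image extensions is an assumption about the dataset.

```python
from pathlib import Path


def count_images_per_class(root, extensions=(".jpg", ".jpeg", ".png")):
    """Return {class_folder_name: number_of_image_files} for a dataset root."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in extensions
            )
    return counts


dataset_root = Path("/opt/dataset/fruit")
if dataset_root.is_dir():  # skip quietly when the dataset is not present
    for name, n in count_images_per_class(dataset_root).items():
        print(f"{name}: {n}")
```

A class with a very small count, or an unexpected folder name, is easier to spot here than after a training run has started.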
3.2 Pre-training model download
The download URLs of the pre-trained models are as follows:
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
Download the resnet50 model and store it in the /opt/models directory.
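One possible way to script this step is sketched below with the standard library only. `download_if_missing` is a hypothetical helper, and the `fetch` parameter exists only so that the download function can be swapped out; the actual call is left commented out because it downloads a large file.

```python
import os
import urllib.request

RESNET50_URL = "https://download.pytorch.org/models/resnet50-0676ba61.pth"


def download_if_missing(url, dest_dir, fetch=urllib.request.urlretrieve):
    """Download url into dest_dir unless the file is already there; return its path."""
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(dest):
        fetch(url, dest)
    return dest


# Usage (commented out because it downloads the full weight file):
# weights_path = download_if_missing(RESNET50_URL, "/opt/models")
```

The returned path ends in resnet50-0676ba61.pth, the same file name that the training script in section 3.3 expects under /opt/models.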
3.3 Transfer learning with a pre-trained ResNet model in PyTorch
Building on ResNet, the fruit recognition model replaces the final fully connected layer so that the 2048-dimensional features produced by the pre-trained ResNet are mapped to one output per fruit class.
import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50

# Image transforms: resize to the ResNet input size and normalize with the ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the dataset; each sub-folder name becomes a class label
dataset = ImageFolder('/opt/dataset/fruit', transform=transform)

# Split into training and validation sets. Splitting the indices and wrapping them in
# Subset keeps image loading lazy instead of reading every image into memory up front.
train_idx, valid_idx = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=10)
train_dataset = Subset(dataset, train_idx)
valid_dataset = Subset(dataset, valid_idx)

batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluate on every validation sample: no shuffling, no dropped final batch
test_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

restnet_pretrained_path = '/opt/models/resnet50-0676ba61.pth'
checkpoint_path = '/opt/checkpoint/fruit_reg.pth'
checkpoint_resume = False

if __name__ == "__main__":
    # Load the pre-trained model
    model = resnet50()
    model.load_state_dict(torch.load(restnet_pretrained_path))

    # Replace the final fully connected layer to build the new network (transfer learning)
    num_classes = len(dataset.classes)
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)

    # Training setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
    num_epochs = 30
    accuracy_rate = []

    for epoch in range(num_epochs):
        print('Epoch [{}/{}], start'.format(epoch + 1, num_epochs))
        # Switch back to training mode; otherwise model.eval() from the
        # previous epoch's validation would stay in effect
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        # Loss of the last training batch in this epoch
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total * 100
        accuracy_rate.append(accuracy)
        print('Accuracy: {:.2f}%'.format(accuracy))

    # Plot the accuracy per epoch
    accuracy_rate = np.array(accuracy_rate)
    times = np.linspace(1, num_epochs, num_epochs)
    plt.xlabel('times')
    plt.ylabel('accuracy rate')
    plt.plot(times, accuracy_rate)
    plt.show()
    print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')},accuracy_rate={accuracy_rate}")

    # Save the trained weights
    torch.save(model.state_dict(), checkpoint_path)

Epoch [1/30], Loss: 1.4853 Accuracy: 63.69%
Epoch [2/30], Loss: 0.2206 Accuracy: 92.35%
Epoch [3/30], Loss: 0.1856 Accuracy: 94.56%
Epoch [4/30], Loss: 0.1025 Accuracy: 93.97%
Epoch [5/30], Loss: 0.0543 Accuracy: 95.31%
Epoch [6/30], Loss: 0.0335 Accuracy: 95.80%
Epoch [7/30], Loss: 0.0114 Accuracy: 95.64%
Epoch [8/30], Loss: 0.0159 Accuracy: 95.20%
Epoch [9/30], Loss: 0.0060 Accuracy: 95.96%
Epoch [10/30], Loss: 0.0027 Accuracy: 96.01%
Epoch [11/30], Loss: 0.0052 Accuracy: 96.07%
Epoch [12/30], Loss: 0.0030 Accuracy: 96.01%
Epoch [13/30], Loss: 0.0035 Accuracy: 96.01%
Epoch [14/30], Loss: 0.0026 Accuracy: 96.12%
Epoch [15/30], Loss: 0.0008 Accuracy: 95.96%
Epoch [16/30], Loss: 0.0013 Accuracy: 96.01%
Epoch [17/30], Loss: 0.0008 Accuracy: 96.17%
Epoch [18/30], Loss: 0.0005 Accuracy: 96.01%
Epoch [19/30], Loss: 0.0010 Accuracy: 96.07%
Epoch [20/30], Loss: 0.0009 Accuracy: 96.07%
Epoch [21/30], Loss: 0.0002 Accuracy: 95.96%
Epoch [22/30], Loss: 0.0002 Accuracy: 96.01%
Epoch [23/30], Loss: 0.0006 Accuracy: 96.39%
Epoch [24/30], Loss: 0.0010 Accuracy: 96.12%
Epoch [25/30], Loss: 0.0008 Accuracy: 96.07%
Epoch [26/30], Loss: 0.0011 Accuracy: 96.01%
Epoch [27/30], Loss: 0.0003 Accuracy: 96.07%
Epoch [28/30], Loss: 0.0006 Accuracy: 96.07%
Epoch [29/30], Loss: 0.0005 Accuracy: 96.07%
Epoch [30/30], Loss: 0.0002 Accuracy: 96.23%
Accuracy levels off at about 96% after roughly 10 epochs. The accuracy curve over the 30 epochs is as follows:
3.4 Training the network without transfer learning in PyTorch
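import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50

# Image transforms: resize to the ResNet input size and normalize with the ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the dataset; each sub-folder name becomes a class label
dataset = ImageFolder('./data/fruit', transform=transform)

# Split into training and validation sets. Splitting the indices and wrapping them in
# Subset keeps image loading lazy instead of reading every image into memory up front.
train_idx, valid_idx = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=0)
train_dataset = Subset(dataset, train_idx)
valid_dataset = Subset(dataset, valid_idx)

batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluate on every validation sample: no shuffling, no dropped final batch
test_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

checkpoint_path = './checkpoint/fruit_reg.pth'
checkpoint_resume = False
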
if __name__ == "__main__":
    # Build ResNet-50 with randomly initialized weights (no pre-trained weights are loaded)
    model = resnet50()

    # Replace the final fully connected layer to build the new network
    num_classes = len(dataset.classes)
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)

    # Training setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    num_epochs = 30  # 30 epochs, matching the training log below
    accuracy_rate = []

    for epoch in range(num_epochs):
        print('Epoch [{}/{}], start'.format(epoch + 1, num_epochs))
        # Switch back to training mode; otherwise model.eval() from the
        # previous epoch's validation would stay in effect
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        # Loss of the last training batch in this epoch
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total * 100
        accuracy_rate.append(accuracy)
        print('Accuracy: {:.2f}%'.format(accuracy))

    # Plot the accuracy per epoch
    accuracy_rate = np.array(accuracy_rate)
    times = np.linspace(1, num_epochs, num_epochs)
    plt.xlabel('times')
    plt.ylabel('accuracy rate')
    plt.plot(times, accuracy_rate)
    plt.show()
    print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')},accuracy_rate={accuracy_rate}")

    # Save the trained weights
    torch.save(model.state_dict(), checkpoint_path)

Epoch [1/30], Loss: 2.1676 Accuracy: 18.32%
Epoch [2/30], Loss: 1.9645 Accuracy: 20.85%
Epoch [3/30], Loss: 1.9394 Accuracy: 37.55%
Epoch [4/30], Loss: 1.3242 Accuracy: 40.46%
Epoch [5/30], Loss: 1.1633 Accuracy: 48.38%
Epoch [6/30], Loss: 1.4852 Accuracy: 52.80%
Epoch [7/30], Loss: 1.0438 Accuracy: 55.01%
Epoch [8/30], Loss: 1.2010 Accuracy: 52.86%
Epoch [9/30], Loss: 0.9826 Accuracy: 55.28%
Epoch [10/30], Loss: 1.0562 Accuracy: 53.72%
Epoch [11/30], Loss: 1.2049 Accuracy: 61.15%
Epoch [12/30], Loss: 1.0919 Accuracy: 59.91%
Epoch [13/30], Loss: 0.7103 Accuracy: 59.81%
Epoch [14/30], Loss: 0.7970 Accuracy: 61.64%
Epoch [15/30], Loss: 1.4505 Accuracy: 60.56%
Epoch [16/30], Loss: 1.0294 Accuracy: 60.02%
Epoch [17/30], Loss: 1.0225 Accuracy: 55.39%
Epoch [18/30], Loss: 0.9417 Accuracy: 64.33%
Epoch [19/30], Loss: 0.7826 Accuracy: 66.06%
Epoch [20/30], Loss: 0.8774 Accuracy: 65.09%
Epoch [21/30], Loss: 0.9671 Accuracy: 63.36%
Epoch [22/30], Loss: 0.7064 Accuracy: 66.81%
Epoch [23/30], Loss: 0.6465 Accuracy: 65.89%
Epoch [24/30], Loss: 0.7217 Accuracy: 64.55%
Epoch [25/30], Loss: 0.7089 Accuracy: 68.05%
Epoch [26/30], Loss: 0.8506 Accuracy: 66.76%
Epoch [27/30], Loss: 0.9541 Accuracy: 67.73%
Epoch [28/30], Loss: 1.1595 Accuracy: 68.21%
Epoch [29/30], Loss: 0.8493 Accuracy: 68.59%
Epoch [30/30], Loss: 0.8297 Accuracy: 71.55%
Without transfer learning (that is, without loading the pre-trained ResNet-50 weights), the model has still not converged after 30 epochs of training, and the accuracy reaches only 71.55%. The resulting accuracy curve is as follows:
4 Summary
Based on the pre-trained ResNet-50 model, this project trains and evaluates a fruit recognition model through transfer learning. After 30 epochs the accuracy on the test set reaches 96%, and the model converges quickly.
In terms of both accuracy and convergence speed, the network trained with transfer learning clearly outperforms the network trained without it, which demonstrates the value of transfer learning.
This transfer learning approach can be extended to other datasets. When data and computing resources are limited, transfer learning makes it possible to train a good classification model quickly.
Complete project code: code address