Transfer Learning: Building an Efficient Fruit Recognition Model Using a ResNet Pre-Trained Model

Table of contents

Introduction

1 Transfer learning

1.1 What is transfer learning

1.2 What problems can transfer learning solve?

1.3 Three problems faced by transfer learning

1.3.1 When to transfer

1.3.2 Where to transfer

1.3.3 How to transfer

1.4 Classification of Transfer Learning

1.4.1 Classification by learning style

1.4.2 Classification by method of use

2 ResNet network

2.1 Introduction to ResNet

2.2 ResNet network structure

3 Transfer learning code implementation

3.1 Dataset Introduction

3.2 Pre-trained model download

3.3 Transfer learning using the ResNet pre-trained model, based on PyTorch

3.4 Network training without transfer learning, based on PyTorch

4 Summary


Introduction

Based on a ResNet pre-trained model, this project builds a fruit classification and recognition model through transfer learning. After 30 epochs of training, the model converged rapidly and reached an accuracy above 96%. Working through this project gives hands-on familiarity with performing transfer learning on top of a pre-trained model to build a new deep learning model. Readers can follow the same approach to add new classes or to build a new classification model on other datasets through transfer learning.

1 Transfer learning

1.1 What is transfer learning

Transfer learning is a machine learning method that takes a model developed for task A as the starting point and reuses it in developing a model for task B. In other words, transfer learning improves learning on a new task by transferring knowledge from related tasks that have already been learned. Although most machine learning algorithms are designed to solve a single task, developing algorithms that facilitate transfer learning is an ongoing topic of interest in the machine learning community.

Generally speaking, transfer learning is the ability to draw inferences from one case to another: it uses existing knowledge to learn new knowledge. Its core is finding the similarity between existing knowledge and new knowledge and transferring across that similarity. Everything in the world has something in common; how to reasonably find the similarities between things, and then use that bridge to help learn new knowledge, is the core issue of transfer learning.

1.2 What problems can transfer learning solve?

Most current artificial intelligence techniques need large amounts of high-quality data. Data constructed in the laboratory can solve this problem to some extent and meet basic training needs. However, when a model is deployed in the field, its predictions are often not accurate enough because the constructed data differs from the real data. This puts a new requirement on AI algorithms: while making full use of laboratory-constructed data, they must also achieve good results on real field data. As a solution, transfer learning lets us train, on a limited dataset, a model whose quality would otherwise require massive amounts of training data, achieving twice the result with half the effort.

The application of transfer learning is not limited to any specific field. As long as a problem fits the transfer learning scenario, transfer learning can be used to solve it. Computer vision, text classification, behavior recognition, natural language processing, indoor positioning, video surveillance, public opinion analysis, and human-computer interaction can all make use of transfer learning techniques.

1.3 Three problems faced by transfer learning

1.3.1 When to transfer

When to transfer concerns whether transfer learning is possible and why it should be used. Note that this is the first step of transfer learning: given the target to be learned, the first thing to do is judge whether the task is suitable for transfer learning at all.

1.3.2 Where to transfer

After judging that the task is suitable for transfer learning, the second step is to decide where to transfer from. For ease of understanding, this can be split into "what" and "where": "what" refers to which knowledge to transfer, which can be neural network weights, certain parameters of a feature transformation matrix, and so on; "where" refers to the place to transfer from, which can be a particular source domain, a particular neuron, a tree in a random forest, etc.

1.3.3 How to transfer

This step is the focus of most transfer learning methods. Given the source domain and the target domain, this step learns the best transfer method to achieve the best performance.

1.4 Classification of Transfer Learning

1.4.1 Classification by learning style

  • Instance-based transfer learning: different weights are assigned to different samples; the more similar a sample is to the target domain, the higher its weight. High-weight samples are given priority and transferred.
  • Feature-based transfer learning: transforms the features. If the source-domain and target-domain features are not in the same space, or are not similar in the original space, we map them into a common space, minimize the distance between corresponding points, and thereby complete the transfer.
  • Model-based transfer learning: reuses the parameters of a model. This type of method is especially common with neural networks, because a neural network's structure can be transferred directly; fine-tuning is a good embodiment of model-parameter transfer.
  • Relation-based transfer learning: explores the relationships in similar scenes. If the relationships within the source domain also hold between the corresponding entities of the target domain, the knowledge can be transferred. For example, teachers lecturing to students can be mapped to speakers presenting at a company meeting; such an analogy is a kind of relational transfer.

1.4.2 Classification by method of use

  • Fine-tuning: take a network trained by someone else, modify it, and use it for your own task, starting from the pre-trained weights instead of randomly initialized weights.
  • Fixed feature extractor: use the pre-trained network as a feature extractor for the new task, i.e., the first layers of the new model are the (frozen) pre-trained network. A minimal sketch contrasting the two approaches is shown below.
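
The two approaches differ only in which parameters are allowed to update during training. The following PyTorch sketch contrasts them, assuming the torchvision ResNet-50 checkpoint downloaded in section 3.2 and a hypothetical 10-class target task:

import torch
from torchvision.models import resnet50

num_classes = 10  # hypothetical number of target classes

# Fine-tuning: start from pre-trained weights and update ALL parameters.
finetune_model = resnet50()
finetune_model.load_state_dict(torch.load('/opt/models/resnet50-0676ba61.pth'))
finetune_model.fc = torch.nn.Linear(finetune_model.fc.in_features, num_classes)
# Every parameter keeps requires_grad=True, so the whole network is trained.

# Fixed feature extractor: freeze the pre-trained backbone and train only
# the newly added classification head.
fixed_model = resnet50()
fixed_model.load_state_dict(torch.load('/opt/models/resnet50-0676ba61.pth'))
for param in fixed_model.parameters():
    param.requires_grad = False  # freeze the backbone
fixed_model.fc = torch.nn.Linear(fixed_model.fc.in_features, num_classes)
# The new fc layer is created after freezing, so it remains trainable.

# In the frozen case, pass only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(fixed_model.fc.parameters(), lr=0.001, momentum=0.9)

The project in section 3.3 takes the fine-tuning route: it loads the full pre-trained network, replaces the final layer, and trains all layers.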

2 ResNet network

2.1 Introduction to ResNet

Deep learning has made major breakthroughs in areas such as image classification, object detection, and speech recognition. However, as the number of network layers increases, the problems of vanishing and exploding gradients become prominent. With more layers, gradients shrink during backpropagation, making the network difficult to converge; conversely, exploding gradients make parameter updates too large, so the network cannot converge normally either.

To solve these problems, ResNet introduced an innovative idea: the residual block. The residual block lets the network learn a residual mapping, which alleviates the vanishing gradient problem and makes the network easier to train.

The figure below shows a basic residual block: the input of a layer is carried forward by a skip connection and added to the output of a deeper layer just before its activation, so that the sum passes through the activation function together.
 
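In code, the skip connection simply adds the block's input to its output just before the final activation. The following is a minimal sketch of such a basic residual block in PyTorch (modeled on, but simplified from, torchvision's BasicBlock; the 1x1 shortcut convolution is one common way to handle shape mismatches between input and output):

import torch
import torch.nn as nn

class BasicResidualBlock(nn.Module):
    """Minimal residual block: output = ReLU(F(x) + x)."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 convolution on the shortcut path when the identity cannot
        # be added directly (channel count or spatial size changes).
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)  # the skip connection
        return self.relu(out)

# Quick shape check.
x = torch.randn(1, 64, 56, 56)
print(BasicResidualBlock(64, 128, stride=2)(x).shape)  # torch.Size([1, 128, 28, 28])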

2.2 ResNet network structure

The classic ResNet network structures are ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152. ResNet-18 and ResNet-34 share the same basic building block and are relatively shallow networks; the latter three are deeper networks, among which ResNet-50 is the most commonly used.
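
As a rough comparison of model sizes, the torchvision implementations can be instantiated and their parameters counted. A quick sketch (the counts include the default 1000-class ImageNet head):

from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152

for name, ctor in [('resnet18', resnet18), ('resnet34', resnet34),
                   ('resnet50', resnet50), ('resnet101', resnet101),
                   ('resnet152', resnet152)]:
    model = ctor()  # randomly initialized weights are enough for counting
    n_params = sum(p.numel() for p in model.parameters())
    print(f'{name}: {n_params / 1e6:.1f}M parameters')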

3 Transfer learning code implementation

3.1 Dataset Introduction

Under the dataset directory there are 10 folders. Each folder is named after a fruit type and contains hundreds to thousands of pictures of that fruit, as shown in the following figure:

Take the apple folder as an example; its contents look like this:

Download address: the first package, the second package

After the two data packages are downloaded, decompress them into the /opt/dataset/fruit directory. When finished, the directory looks like this:

# ll fruit/
total 508
drwxr-xr-x 2 root root 36864 Aug  2 16:35 apple
drwxr-xr-x 2 root root 24576 Aug  2 16:36 apricot
drwxr-xr-x 2 root root 40960 Aug  2 16:36 banana
drwxr-xr-x 2 root root 20480 Aug  2 16:36 blueberry
drwxr-xr-x 2 root root 45056 Aug  2 16:37 cherry
drwxr-xr-x 2 root root 12288 Aug  2 16:37 citrus
drwxr-xr-x 2 root root 49152 Aug  2 16:38 grape
drwxr-xr-x 2 root root 16384 Aug  2 16:38 lemon
drwxr-xr-x 2 root root 36864 Aug  2 16:39 litchi
drwxr-xr-x 2 root root 49152 Aug  2 16:39 mango

3.2 Pre-trained model download

The download addresses of the pre-trained models are as follows:

    model_urls = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
        'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
        'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
        'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
        'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
    }

Download the ResNet-50 model and store it in the /opt/models directory.
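
The .pth file can be downloaded in a browser, or fetched and cached programmatically. A small sketch using torch.hub (the model_dir below matches the path the training script expects):

import torch

url = 'https://download.pytorch.org/models/resnet50-0676ba61.pth'
# Downloads the checkpoint into /opt/models (skipped if already present)
# and returns the loaded state dict.
state_dict = torch.hub.load_state_dict_from_url(url, model_dir='/opt/models')
print(f'{len(state_dict)} tensors in the checkpoint')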

3.3 Transfer learning using the ResNet pre-trained model, based on PyTorch

Building on ResNet, the fruit recognition model replaces the final fully connected layer so that the 2048-dimensional features of the ResNet-50 pre-trained backbone map to the number of fruit classes.

import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import datetime
import numpy as np

from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.models import resnet50
from sklearn.model_selection import train_test_split

# Image transforms
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                     mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]
                                ), ])
# Load the dataset
dataset = ImageFolder('/opt/dataset/fruit', transform=transform)

# Split into training and validation sets
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.2, random_state=10)

batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

resnet_pretrained_path = '/opt/models/resnet50-0676ba61.pth'
checkpoint_path = '/opt/checkpoint/fruit_reg.pth'
checkpoint_resume = False

if __name__ == "__main__":
    # Load the pre-trained model weights
    model = resnet50()
    model.load_state_dict(torch.load(resnet_pretrained_path))

    # Replace the final fully connected layer to build the new network for transfer learning
    num_classes = len(dataset.classes)
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)

    # Model training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
    num_epochs = 30

    accuracy_rate = []
    for epoch in range(num_epochs):
        print('Epoch [{}/{}], start'.format(epoch + 1, num_epochs))

        model.train()  # ensure training mode each epoch (model.eval() is set during validation below)
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))

        # Model validation
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total * 100
        accuracy_rate.append(accuracy)
        print('Accuracy: {:.2f}%'.format(accuracy))

    accuracy_rate = np.array(accuracy_rate)
    times = np.linspace(1, num_epochs, num_epochs)
    plt.xlabel('epoch')
    plt.ylabel('accuracy (%)')
    plt.plot(times, accuracy_rate)
    plt.show()

    print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')},accuracy_rate={accuracy_rate}")
    torch.save(model.state_dict(), checkpoint_path)

Training output:

Epoch [1/30],  Loss: 1.4853    Accuracy: 63.69%
Epoch [2/30],  Loss: 0.2206    Accuracy: 92.35%
Epoch [3/30],  Loss: 0.1856    Accuracy: 94.56%
Epoch [4/30],  Loss: 0.1025    Accuracy: 93.97%
Epoch [5/30],  Loss: 0.0543    Accuracy: 95.31%
Epoch [6/30],  Loss: 0.0335    Accuracy: 95.80%
Epoch [7/30],  Loss: 0.0114    Accuracy: 95.64%
Epoch [8/30],  Loss: 0.0159    Accuracy: 95.20%
Epoch [9/30],  Loss: 0.0060    Accuracy: 95.96%
Epoch [10/30], Loss: 0.0027    Accuracy: 96.01%
Epoch [11/30], Loss: 0.0052    Accuracy: 96.07%
Epoch [12/30], Loss: 0.0030    Accuracy: 96.01%
Epoch [13/30], Loss: 0.0035    Accuracy: 96.01%
Epoch [14/30], Loss: 0.0026    Accuracy: 96.12%
Epoch [15/30], Loss: 0.0008    Accuracy: 95.96%
Epoch [16/30], Loss: 0.0013    Accuracy: 96.01%
Epoch [17/30], Loss: 0.0008    Accuracy: 96.17%
Epoch [18/30], Loss: 0.0005    Accuracy: 96.01%
Epoch [19/30], Loss: 0.0010    Accuracy: 96.07%
Epoch [20/30], Loss: 0.0009    Accuracy: 96.07%
Epoch [21/30], Loss: 0.0002    Accuracy: 95.96%
Epoch [22/30], Loss: 0.0002    Accuracy: 96.01%
Epoch [23/30], Loss: 0.0006    Accuracy: 96.39%
Epoch [24/30], Loss: 0.0010    Accuracy: 96.12%
Epoch [25/30], Loss: 0.0008    Accuracy: 96.07%
Epoch [26/30], Loss: 0.0011    Accuracy: 96.01%
Epoch [27/30], Loss: 0.0003    Accuracy: 96.07%
Epoch [28/30], Loss: 0.0006    Accuracy: 96.07%
Epoch [29/30], Loss: 0.0005    Accuracy: 96.07%
Epoch [30/30], Loss: 0.0002    Accuracy: 96.23%

The model starts to converge after about 10 epochs; the accuracy curve over all 30 epochs (plotted by the script's final matplotlib call) rises quickly and then plateaus around 96%.
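
To use the trained model for recognition afterwards, the saved checkpoint can be reloaded and applied to a single image. A minimal sketch, reusing transform, dataset, and checkpoint_path from the training script above (the image path here is hypothetical):

import torch
from PIL import Image
from torchvision.models import resnet50

model = resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, len(dataset.classes))
model.load_state_dict(torch.load(checkpoint_path))
model.eval()

img = Image.open('/opt/dataset/fruit/apple/example.jpg').convert('RGB')  # hypothetical image
x = transform(img).unsqueeze(0)  # add a batch dimension: [1, 3, 224, 224]
with torch.no_grad():
    predicted_index = model(x).argmax(dim=1).item()
print(dataset.classes[predicted_index])  # folder name, e.g. 'apple'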

3.4 Network training without transfer learning, based on PyTorch

import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import datetime
import numpy as np

from torchvision.datasets import ImageFolder
from torchvision import transforms

from torchvision.models import resnet50

from sklearn.model_selection import train_test_split

# Image transforms
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                     mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]
                                ), ])
# Load the dataset
dataset = ImageFolder('./data/fruit', transform=transform)

# Split into training and validation sets
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.2, random_state=0)

batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

checkpoint_path = './checkpoint/fruit_reg.pth'
checkpoint_resume = False

if __name__ == "__main__":
    # Build the model with randomly initialized weights (no pre-trained weights loaded)
    model = resnet50()

    # Replace the final fully connected layer to build the new network
    num_classes = len(dataset.classes)
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)

    # Model training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    num_epochs = 30

    accuracy_rate = []
    for epoch in range(num_epochs):

        print('Epoch [{}/{}], start'.format(epoch + 1, num_epochs))

        model.train()  # ensure training mode each epoch (model.eval() is set during validation below)
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))

        # Model validation
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total * 100
        accuracy_rate.append(accuracy)
        print('Accuracy: {:.2f}%'.format(accuracy))

    accuracy_rate = np.array(accuracy_rate)
    times = np.linspace(1, num_epochs, num_epochs)
    plt.xlabel('epoch')
    plt.ylabel('accuracy (%)')
    plt.plot(times, accuracy_rate)
    plt.show()

    print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')},accuracy_rate={accuracy_rate}")
    torch.save(model.state_dict(), checkpoint_path)

Training output:

Epoch [1/30],  Loss: 2.1676    Accuracy: 18.32%
Epoch [2/30],  Loss: 1.9645    Accuracy: 20.85%
Epoch [3/30],  Loss: 1.9394    Accuracy: 37.55%
Epoch [4/30],  Loss: 1.3242    Accuracy: 40.46%
Epoch [5/30],  Loss: 1.1633    Accuracy: 48.38%
Epoch [6/30],  Loss: 1.4852    Accuracy: 52.80%
Epoch [7/30],  Loss: 1.0438    Accuracy: 55.01%
Epoch [8/30],  Loss: 1.2010    Accuracy: 52.86%
Epoch [9/30],  Loss: 0.9826    Accuracy: 55.28%
Epoch [10/30], Loss: 1.0562    Accuracy: 53.72%
Epoch [11/30], Loss: 1.2049    Accuracy: 61.15%
Epoch [12/30], Loss: 1.0919    Accuracy: 59.91%
Epoch [13/30], Loss: 0.7103    Accuracy: 59.81%
Epoch [14/30], Loss: 0.7970    Accuracy: 61.64%
Epoch [15/30], Loss: 1.4505    Accuracy: 60.56%
Epoch [16/30], Loss: 1.0294    Accuracy: 60.02%
Epoch [17/30], Loss: 1.0225    Accuracy: 55.39%
Epoch [18/30], Loss: 0.9417    Accuracy: 64.33%
Epoch [19/30], Loss: 0.7826    Accuracy: 66.06%
Epoch [20/30], Loss: 0.8774    Accuracy: 65.09%
Epoch [21/30], Loss: 0.9671    Accuracy: 63.36%
Epoch [22/30], Loss: 0.7064    Accuracy: 66.81%
Epoch [23/30], Loss: 0.6465    Accuracy: 65.89%
Epoch [24/30], Loss: 0.7217    Accuracy: 64.55%
Epoch [25/30], Loss: 0.7089    Accuracy: 68.05%
Epoch [26/30], Loss: 0.8506    Accuracy: 66.76%
Epoch [27/30], Loss: 0.9541    Accuracy: 67.73%
Epoch [28/30], Loss: 1.1595    Accuracy: 68.21%
Epoch [29/30], Loss: 0.8493    Accuracy: 68.59%
Epoch [30/30], Loss: 0.8297    Accuracy: 71.55%

Without transfer learning (that is, without loading the ResNet-50 pre-trained weights), the model has still not converged after 30 epochs of training, as the accuracy curve above shows.

4 Summary

Based on the ResNet-50 pre-trained model, this project trains a fruit recognition model through transfer learning. After 30 epochs, the accuracy on the test set reaches 96%, and the model converges rapidly.

In terms of training results, both in accuracy and in convergence speed, the network trained with transfer learning far outperforms the network trained without it, which fully demonstrates the value of transfer learning.

This transfer learning approach can be extended to other datasets. With limited data and computing resources, transfer learning makes it possible to quickly train a good classification and recognition model.

Complete project code: code address
