[Computer Vision | PyTorch] A detailed introduction to the timm package, with an image classification example (source code included)

1. Introduction to timm

timm (PyTorch Image Models) is a computer vision model library implemented natively in PyTorch. It provides pretrained models and a variety of network components that can be used for computer vision tasks such as image classification, object detection, and semantic segmentation.

The main features of timm are as follows:

  1. Native PyTorch implementation: timm is fully compatible with PyTorch, so developers can use the familiar PyTorch API for model training and deployment.
  2. Lightweight design: timm favors lightweight designs and offers a variety of lightweight network architectures for different computer vision tasks.
  3. A large number of pretrained models: timm ships a large collection of pretrained models that can be used directly for various computer vision tasks (see the snippet after this list).
  4. Rich model components: timm provides reusable components such as attention modules, regularization modules, and activation functions that can easily be inserted into your own models.
  5. Efficient implementation: timm's code is efficient and easy to use.
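
For example, item 3 can be checked directly: timm.list_models() returns the registered architecture names, and pretrained=True filters to those with downloadable weights (the exact counts vary by timm version):

import timm

# ResNet variants that ship pretrained weights (first few entries)
print(timm.list_models('resnet*', pretrained=True)[:5])
# Total number of registered architectures in this timm version
print(len(timm.list_models()))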

Note that timm is a community-driven project developed and maintained by experts in the computer vision community; when using it, follow the applicable license terms.

2. Image classification example

The following walks through an image classification task implemented with timm.

2.1 Install the timm package

!pip install timm
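
After installing, a quick import confirms the package is available (the version string depends on when you install):

import timm
print(timm.__version__)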

2.2 Import modules and load the dataset

import torch
import torch.nn as nn
import timm
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# Data augmentation for the training set
train_transforms = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Datasets
train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transforms)
test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transforms)

# DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

Import the required modules; among them, timm and torchvision.datasets.CIFAR10 require the timm and torchvision packages to be installed, respectively.

Next, define the data augmentation pipelines. Different transforms are used for the training and test sets, and the images are normalized in both. transforms.Compose() chains multiple operations into a single transform pipeline, transforms.ToTensor() converts an image to a tensor, and transforms.Normalize() normalizes it with the given per-channel mean and standard deviation.

The built-in CIFAR10 dataset is used: train=True selects the training split, and train=False selects the test split. The dataset is downloaded automatically to the given root path, and the transforms are applied as samples are read.

torch.utils.data.DataLoader wraps each dataset in an efficient iterable: batch_size sets the batch size, shuffle controls whether the data is shuffled each epoch, and num_workers sets how many worker processes load the data.
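
As a quick sanity check (not required for training), pulling one batch from the loader shows the expected shapes for CIFAR10: 64 RGB images of 32x32 pixels and 64 labels.

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 3, 32, 32]): batch, channels, height, width
print(labels.shape)  # torch.Size([64])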


2.3 Define the model

# Load a pretrained model
model = timm.create_model('resnet18', pretrained=True)

# Replace the classifier head
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))

The timm.create_model() function creates the model: the first argument, resnet18, selects the architecture, and pretrained=True loads pretrained weights.

The model's classifier is then replaced. First, model.fc.in_features reads the number of input features of the model's fc layer; then nn.Linear() defines a new linear layer whose input size matches the previous layer's output and whose output size is the number of classes, len(train_dataset.classes). Sizing the output layer from the dataset's class count lets the same code adapt to different classification tasks.

Here we take the ResNet18 model from timm and adapt its classifier to our task; passing pretrained=True when creating the model loads the pretrained weights.
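
As an aside, timm can also size the classifier head for you: create_model() accepts a num_classes argument, so the manual replacement of model.fc above can be folded into a single call.

# Equivalent shortcut: let timm build the classifier head directly
model = timm.create_model('resnet18', pretrained=True, num_classes=len(train_dataset.classes))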

2.4 Define the loss function and optimizer

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In deep learning, the loss function measures the difference between the model's predictions and the ground-truth labels and is what training minimizes. nn.CrossEntropyLoss() is the standard loss for multi-class classification.

The optimizer updates the model parameters to minimize the loss. Here we use the stochastic gradient descent (SGD) optimizer to control how the model weights change: model.parameters() specifies the parameters to optimize, lr sets the learning rate (the step size of each update), and momentum mixes a fraction of the previous update into the current one, which reduces the variance of the updates and stabilizes training.
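
To make the momentum term concrete, here is one manual parameter step mirroring the update rule torch.optim.SGD applies with these settings (weight decay and dampening at their defaults; the numbers are made up for illustration):

import torch

p = torch.tensor([1.0])      # a parameter
v = torch.zeros_like(p)      # velocity buffer (starts at zero)
grad = torch.tensor([0.2])   # gradient of the loss w.r.t. p
lr, momentum = 0.001, 0.9

v = momentum * v + grad      # v = 0.2 on the first step
p = p - lr * v               # p = 1.0 - 0.001 * 0.2 = 0.9998
print(p)                     # tensor([0.9998])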

2.5 Train the model

num_epochs = 10

for epoch in range(num_epochs):
    # Training phase
    model.train()
    for images, labels in train_loader:
        # Forward pass
        outputs = model(images)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluation phase
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Epoch {} Accuracy: {:.2f}%'.format(epoch+1, 100*correct/total))

This code is the training and evaluation loop. num_epochs sets the number of epochs; each epoch is one full pass over the training set.

In the training phase, the model is first switched to training mode, then the training data is read iteratively from train_loader; for each batch we run the forward pass, compute the loss, backpropagate, and take an optimizer step.

In the test phase, the model is switched to evaluation mode, the test data is read from test_loader, the forward pass produces predictions, and the predictions are compared against the ground-truth labels to compute and print the accuracy for each epoch.
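
Note that the loop above runs on the CPU. A common extension, assuming a CUDA-capable GPU is available, is to move the model once and each batch to the device; a minimal sketch:

# Select a device and move the model to it (falls back to CPU if no GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Inside both the training and test loops, move each batch before the forward pass:
# images, labels = images.to(device), labels.to(device)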

Here, torch.max(outputs.data, 1) returns the maximum value in each row together with its index, i.e. the predicted class; total accumulates the number of test samples and correct the number of correctly classified ones, from which the accuracy is computed and printed.
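
For example, on a toy batch of logits (made-up numbers), torch.max along dimension 1 picks the highest-scoring class per row:

logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.9]])
values, predicted = torch.max(logits, 1)
print(predicted)  # tensor([1, 0]): index of the largest value in each row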

Running the loop prints one line of test accuracy per epoch, in the format Epoch N Accuracy: XX.XX%.

Origin: blog.csdn.net/wzk4869/article/details/130666405