Using the self-supervised contrastive learning model SimCLR for image classification: a detailed explanation of the PyTorch code

Code comes from: SimCLR image classification - pytorch reproduction
SimCLR framework: SimCLR framework analysis

1. Define the network structures of the unsupervised and supervised stages and the loss function

1.1 Task breakdown

Unsupervised part: ResNet-50 is used as the feature extractor. Its first convolution is replaced with a 3×3, stride-1 convolution, and its max-pooling layer and final fully connected layer are removed (the global average pooling layer is kept). The resulting feature map is flattened and passed through a fully connected layer, batch normalization, ReLU activation, and a second fully connected layer, in that order, to obtain the output features.
Supervised part: for the downstream classification task, the feature extraction layers and trained parameters of the unsupervised network are reused, and a single fully connected layer produces the classification output.
Loss function: the similarity between positive pairs (two augmented views of the same image) is maximized relative to the similarity with negative samples (views of other images), pulling positive samples closer together and pushing negative samples farther apart.

The downstream task takes the features extracted by the unsupervised encoder and adds a fully connected layer to produce the classification output.

The SimCLR paper uses ResNet-50 as the convolutional encoder, which yields a 1×2048 representation per image.


1.2 Code

# net.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50


# stage one: unsupervised learning
class SimCLRStage1(nn.Module):
    def __init__(self, feature_dim=128):
        super(SimCLRStage1, self).__init__()

        self.f = []
        for name, module in resnet50().named_children():
            if name == 'conv1':
                module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
                self.f.append(module)
        # encoder
        self.f = nn.Sequential(*self.f)
        # projection head
        self.g = nn.Sequential(nn.Linear(2048, 512, bias=False),
                               nn.BatchNorm1d(512),
                               nn.ReLU(inplace=True),
                               nn.Linear(512, feature_dim, bias=True))

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.g(feature)
        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)


# stage two: supervised learning
class SimCLRStage2(torch.nn.Module):
    def __init__(self, num_class):
        super(SimCLRStage2, self).__init__()
        # encoder
        self.f = SimCLRStage1().f
        # classifier
        self.fc = nn.Linear(2048, num_class, bias=True)

        for param in self.f.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.fc(feature)
        return out


class Loss(torch.nn.Module):
    def __init__(self):
        super(Loss,self).__init__()

    def forward(self,out_1,out_2,batch_size,temperature=0.5):
        # [2*B, D]
        out = torch.cat([out_1, out_2], dim=0)
        # [2*B, 2*B]
        sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
        mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool()
        # [2*B, 2*B-1]
        sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)

        # numerator: * is element-wise multiplication; summing over the last dim gives the dot product
        # compute loss
        pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        # [2*B]
        pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
        return (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()


if __name__=="__main__":
    for name, module in resnet50().named_children():
        print(name,module)


1.3 Detailed code explanation

1.3.1 SimCLRStage1

  • class SimCLRStage1

A class named SimCLRStage1 is defined, inheriting from nn.Module. This class is used to implement the first stage of self-supervised learning.

class SimCLRStage1(nn.Module):
    def __init__(self, feature_dim=128):
        super(SimCLRStage1, self).__init__()

        self.f = []
        for name, module in resnet50().named_children():
            if name == 'conv1':
                module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
                self.f.append(module)
        # encoder
        self.f = nn.Sequential(*self.f)
        # projection head
        self.g = nn.Sequential(nn.Linear(2048, 512, bias=False),
                               nn.BatchNorm1d(512),
                               nn.ReLU(inplace=True),
                               nn.Linear(512, feature_dim, bias=True))

In the constructor __init__, resnet50().named_children() iterates over the top-level submodules of a freshly built (untrained) ResNet-50. The submodule named 'conv1' (a 7×7, stride-2 convolution designed for 224×224 ImageNet images) is replaced with a 3×3, stride-1 convolution that keeps the same 3→64 channel mapping but preserves the spatial resolution of small 32×32 CIFAR-10 images. Every submodule except nn.Linear (the final fc layer) and nn.MaxPool2d (the stem max-pooling layer) is appended to the list self.f; the global AdaptiveAvgPool2d layer is kept, so the encoder outputs a [B, 2048, 1, 1] tensor. The list is then wrapped in nn.Sequential and stored back in self.f as the encoder. Finally, the projection head self.g maps the 2048-dimensional encoder features to a lower-dimensional space: Linear(2048→512) → BatchNorm1d → ReLU → Linear(512→feature_dim), with feature_dim = 128 by default.

  • forward propagation forward function
    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.g(feature)
        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)

This method defines the forward pass. The input x first passes through the encoder self.f, which outputs a [B, 2048, 1, 1] tensor; torch.flatten(x, start_dim=1) then reshapes it into a 2D [B, 2048] feature matrix. The flattened features are fed to the projection head self.g to obtain the projected representation out of shape [B, feature_dim]. Finally, F.normalize scales each row to unit L2 norm, and the method returns both the normalized encoder features feature and the normalized projection-head output out. Stage-one training uses only the second value.
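Because both outputs are L2-normalized, a plain dot product between two rows equals their cosine similarity, which is what the loss function relies on. A small check with random tensors (stand-ins for model outputs, not real features):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out = torch.randn(4, 128)             # stand-in for a batch of projection-head outputs

z = F.normalize(out, dim=-1)          # divide each row by its L2 norm
row_norms = z.norm(dim=-1)
print(row_norms)                      # every entry is (numerically) 1.0

dot = z @ z.t()                       # pairwise dot products of unit vectors, [4, 4]
cos = F.cosine_similarity(out.unsqueeze(1), out.unsqueeze(0), dim=-1)  # [4, 4]
print(torch.allclose(dot, cos, atol=1e-5))  # True
```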

1.3.2 SimCLRStage2

  • class SimCLRStage2

This code defines a class named SimCLRStage2, which inherits from nn.Module. It implements the second stage of the pipeline: supervised classification on top of the encoder learned in the first stage.

class SimCLRStage2(torch.nn.Module):
    def __init__(self, num_class):
        super(SimCLRStage2, self).__init__()
        # encoder
        self.f = SimCLRStage1().f
        # classifier
        self.fc = nn.Linear(2048, num_class, bias=True)

        for param in self.f.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.fc(feature)
        return out

In the constructor __init__, the encoder self.f is taken from a new SimCLRStage1 instance via SimCLRStage1().f. At this point it is randomly initialized; the stage-one weights are loaded into it later with load_state_dict (see trainstage2.py). A linear classifier self.fc then maps the 2048-dimensional encoder features to num_class category scores. At the end of the constructor, the encoder's parameters are frozen by setting requires_grad = False, so only self.fc is trained.
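Setting requires_grad = False stops gradient updates, but it does not stop BatchNorm layers from updating their running mean and variance: trainstage2.py calls model.train() every epoch, so the statistics inside the frozen encoder keep moving. The sketch below demonstrates this on TinyStage2, a hypothetical miniature with the same f/fc layout (not the article's model). Putting self.f in eval mode is one optional way to freeze the statistics as well; the original code does not do this.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyStage2(nn.Module):
    """Hypothetical miniature of SimCLRStage2: BatchNorm encoder f + linear classifier fc."""
    def __init__(self, num_class=10):
        super().__init__()
        self.f = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8),
                               nn.ReLU(), nn.AdaptiveAvgPool2d(1))
        self.fc = nn.Linear(8, num_class)
        for param in self.f.parameters():   # same freezing pattern as SimCLRStage2
            param.requires_grad = False

    def forward(self, x):
        return self.fc(torch.flatten(self.f(x), start_dim=1))

model = TinyStage2()
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)                       # ['fc.weight', 'fc.bias']

bn = model.f[1]
before = bn.running_mean.clone()
model.train()                          # what trainstage2.py does at the start of each epoch
model(torch.randn(4, 3, 32, 32))
changed_in_train = not torch.equal(before, bn.running_mean)
print(changed_in_train)                # True: BatchNorm statistics still moved

model.f.eval()                         # optional: freeze the statistics too
before = bn.running_mean.clone()
model(torch.randn(4, 3, 32, 32))
unchanged_in_eval = torch.equal(before, bn.running_mean)
print(unchanged_in_eval)               # True
```

Whether to fix the statistics is a design choice: leaving them in train mode lets them adapt to the stage-two data and augmentations, at the cost of the encoder not being strictly frozen.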

  • forward propagation forward function
    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.fc(feature)
        return out

This method defines the forward pass. The input x first passes through the frozen encoder self.f, and torch.flatten reshapes the resulting [B, 2048, 1, 1] tensor into [B, 2048]. The flattened features are fed to the linear classifier self.fc, which returns the unnormalized class scores (logits) out of shape [B, num_class].

1.3.3 loss function

SimCLR uses a loss function called NT-Xent loss, short for Normalized Temperature-scaled Cross Entropy loss.
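Written as a formula, the loss computed by the forward method that follows is shown below, with N = batch_size, τ = temperature, and z_1, …, z_{2N} the rows of out, so that z_i and z_{i+N} are the two views of image i:

```latex
\ell_i = -\log \frac{\exp\left(z_i^\top z_{p(i)} / \tau\right)}
                    {\sum_{k=1,\, k \neq i}^{2N} \exp\left(z_i^\top z_k / \tau\right)},
\qquad
p(i) = \begin{cases} i + N, & i \le N \\ i - N, & i > N \end{cases}

\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{2N} \ell_i
```

Since every z_i has unit norm, z_i^T z_k is a cosine similarity. The numerator corresponds to pos_sim, the denominator to sim_matrix.sum(dim=-1), and the final .mean() to the 1/(2N) average.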

def forward(self, out_1, out_2, batch_size, temperature=0.5):
    # concatenate the feature representations of the two views
    out = torch.cat([out_1, out_2], dim=0)

    # compute the (exponentiated) similarity matrix
    sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)

    # build the mask that excludes each sample's similarity to itself
    mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool()

    # keep only the off-diagonal entries of the similarity matrix
    sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)

    # similarity of the positive pairs (numerator)
    pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos_sim = torch.cat([pos_sim, pos_sim], dim=0)

    # compute the loss
    loss = (-torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()

    return loss

  1. out_1 and out_2 are the normalized projection-head outputs for the two augmented views of the same batch, each with shape [batch_size, feature_dim].
  2. out is obtained by concatenating out_1 and out_2 along dimension 0, giving shape [2 * batch_size, feature_dim]. Row i and row i + batch_size are the two views of the same image.
  3. torch.mm(out, out.t().contiguous()) multiplies out by its transpose, giving a [2 * batch_size, 2 * batch_size] matrix of pairwise dot products. Because every row of out has unit length, these dot products are cosine similarities. The .contiguous() call only makes a contiguous copy of the transposed view; torch.mm also accepts non-contiguous inputs, so it is not required for correctness.
  4. Each similarity is divided by the temperature parameter temperature (a smaller temperature sharpens the distribution) and exponentiated with torch.exp, giving sim_matrix.
  5. Next, a boolean mask with the same shape as the similarity matrix is created. Its diagonal elements are False and all other elements are True, which excludes the similarity of each feature with itself.
  6. masked_select applies the mask, and view reshapes the remaining values to [2 * batch_size, 2 * batch_size - 1]: each row holds the similarities of one sample to all other samples (its positive plus 2 * batch_size - 2 negatives).
  7. torch.sum(out_1 * out_2, dim=-1) computes the row-wise dot product (cosine similarity) between the two views of each image; it is divided by temperature and exponentiated to give pos_sim.
  8. pos_sim is concatenated with itself along dimension 0, giving a tensor of size [2 * batch_size], because sample i and sample i + batch_size share the same positive pair.
  9. The loss is -torch.log(pos_sim / sim_matrix.sum(dim=-1)).mean(): the numerator is the positive-pair similarity, the denominator is the row-wise sum over all 2 * batch_size - 1 non-self similarities, and the negative log of this ratio is averaged over all 2 * batch_size samples.
  10. Finally, the computed loss is returned.
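One way to check the implementation is that the same value can be written as a cross-entropy over the masked similarity logits, where each row's "correct class" is the index of its other view. The sketch below copies the body of net.Loss.forward into nt_xent_reference and compares it with nt_xent_cross_entropy, a hypothetical alternative formulation that is not part of the original code; random unit vectors stand in for projection-head outputs.

```python
import torch
import torch.nn.functional as F

def nt_xent_reference(out_1, out_2, batch_size, temperature=0.5):
    # body of net.Loss.forward, unchanged
    out = torch.cat([out_1, out_2], dim=0)
    sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
    mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool()
    sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)
    pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
    return (-torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()

def nt_xent_cross_entropy(out_1, out_2, temperature=0.5):
    n = out_1.size(0)
    out = torch.cat([out_1, out_2], dim=0)                            # [2N, D]
    logits = out @ out.t() / temperature                              # [2N, 2N]
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=out.device)
    logits = logits.masked_fill(self_mask, float('-inf'))             # remove self-similarity
    targets = (torch.arange(2 * n, device=out.device) + n) % (2 * n)  # index of the other view
    return F.cross_entropy(logits, targets)

torch.manual_seed(0)
z1 = F.normalize(torch.randn(8, 128), dim=-1)
z2 = F.normalize(torch.randn(8, 128), dim=-1)
a = nt_xent_reference(z1, z2, batch_size=8)
b = nt_xent_cross_entropy(z1, z2)
print(a.item(), b.item())   # the two values agree up to floating-point error
```

The cross-entropy form works with log-probabilities instead of exponentiating first. That can matter at small temperatures: with unit vectors the largest logit is 1/τ, and exp(1/τ) overflows float32 once τ drops to around 0.01.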

2. Configuration file config.py

Configuration items define some parameters and data preprocessing operations during training and testing, making the code more flexible and configurable. These configuration items can be imported in other code and adjusted and used as needed.

2.1 Code

# config.py
import os
from torchvision import transforms
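# use_gpu is a boolean indicating whether to train on the GPU.
# gpu_name is an integer giving the index of the GPU device to use.
use_gpu=True
gpu_name=1

# pre_model is a string holding the path of the pretrained (stage-one) model.
# os.path.join('pth','model.pth') joins the two path parts into a full path; here the pretrained model path is pth/model.pth.
# (stage one saves model_stage1_epoch<N>.pth, so rename a checkpoint or pass --pre_model to trainstage2.py)
pre_model=os.path.join('pth','model.pth')

# save_path is a string holding the directory where model files are saved.
# In this example, model files are saved in the pth folder.
save_path="pth"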

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

2.2 Explanation

train_transform is a torchvision.transforms.Compose object that defines the preprocessing (augmentation) pipeline for training data.
In this example, the operations are, in order: a random resized crop (a random area and aspect ratio, resized to 32×32), a random horizontal flip (probability 0.5), random color jitter of brightness, contrast, saturation and hue (applied with probability 0.8), random conversion to grayscale (probability 0.2), conversion to a tensor, and per-channel normalization with the CIFAR-10 mean and standard deviation.

test_transform is similar to train_transform, but only includes converting images into tensors and image normalization operations, which are used to preprocess test data.

3. Unsupervised learning data loading loaddataset.py

The CIFAR-10 data set is used, which contains a total of 10 categories of RGB color pictures. The size of the pictures is 32×32. There are a total of 50,000 training pictures and 10,000 test pictures in the data set.

3.1 Code

loaddataset.py : Customized data set class PreDataset, inherited from torchvision.datasets.CIFAR10.

# loaddataset.py
from torchvision.datasets import CIFAR10
from PIL import Image


class PreDataset(CIFAR10):
    def __getitem__(self, item):
        img,target=self.data[item],self.targets[item]
        img = Image.fromarray(img)

        if self.transform is not None:
            imgL = self.transform(img)
            imgR = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return imgL, imgR, target


if __name__=="__main__":

    import config
    train_data = PreDataset(root='dataset', train=True, transform=config.train_transform, download=True)
    print(train_data[0])

3.2 Detailed code explanation

3.2.1 Override the __getitem__ method

class PreDataset(CIFAR10):
    def __getitem__(self, item):
        img,target=self.data[item],self.targets[item]
        img = Image.fromarray(img)

        if self.transform is not None:
            imgL = self.transform(img)
            imgR = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return imgL, imgR, target

In PreDataset, the __getitem__ method is overridden. This method is called when retrieving a sample by subscript index. Here, it first obtains the original image data self.data[item] and the corresponding label self.targets[item].

It then uses PIL.Image.fromarray to convert the raw image array (a 32×32×3 uint8 array) into a PIL image object img.

Next, if self.transform is not None (that is, a preprocessing pipeline was passed in when the dataset object was created), the transform is applied to the same image twice, and the results are stored in imgL and imgR. Because train_transform is random, the two calls produce two different augmented views of the same image; these form the positive pair for contrastive learning. Note that if no transform is given, imgL and imgR are never assigned, so PreDataset always needs a transform.

Then, if self.target_transform is not None (a preprocessing operation for the target label was passed in when the dataset object was created), the label is preprocessed as well.

Finally, imgL, imgR and target are returned as the contents of the sample.
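The two-view pattern does not depend on CIFAR-10. The sketch below reproduces it with a hypothetical TwoViewDataset over random tensors (so nothing is downloaded) and a toy random_flip augmentation standing in for config.train_transform. It shows that the default DataLoader collation turns each (imgL, imgR, target) sample into three batched tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class TwoViewDataset(Dataset):
    """Hypothetical stand-in for PreDataset: two independently augmented views per sample."""
    def __init__(self, images, targets, transform):
        self.images, self.targets, self.transform = images, targets, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        # the transform runs twice, so each view gets its own random parameters
        return self.transform(img), self.transform(img), self.targets[idx]

def random_flip(x):
    # toy augmentation: flip along the width axis with probability 0.5
    return torch.flip(x, dims=[-1]) if torch.rand(1).item() < 0.5 else x

torch.manual_seed(0)
images = torch.rand(10, 3, 32, 32)      # 10 fake "images" already in tensor form
targets = torch.arange(10) % 3          # fake labels
loader = DataLoader(TwoViewDataset(images, targets, random_flip),
                    batch_size=4, shuffle=True, drop_last=True)
imgL, imgR, labels = next(iter(loader))
print(imgL.shape, imgR.shape, labels.shape)  # [4, 3, 32, 32], [4, 3, 32, 32], [4]
print(len(loader))                           # 2 batches: 10 // 4 with drop_last=True
```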

3.2.2 main method

if __name__=="__main__":
    import config
    train_data = PreDataset(root='dataset', train=True, transform=config.train_transform, download=True)
    print(train_data[0])

Under the if __name__ == "__main__": guard, the config.py configuration file defined earlier is imported. Then a PreDataset object train_data is created with the relevant parameters: the dataset root directory, train=True to select the training split, the preprocessing pipeline transform, and download=True. Finally, the contents of the first sample are printed.

The function of this code is to define a custom data set class PreDataset and instantiate this class object in the main function for loading and processing the data set. By overriding the __getitem__ method, you can implement preprocessing operations on images and labels, and return preprocessed samples.

4. Unsupervised training: trainstage1.py

This script trains a model with the self-supervised SimCLR method. It loads the training dataset, defines the model, loss function, and optimizer, and then updates the model parameters by looping over batches of training data. The loss values during training are recorded and saved to a file.

4.1 Code

# trainstage1.py
import torch,argparse,os
import net,config,loaddataset


# train stage one
def train(args):
    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
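        # useful when the computation graph changes little between iterations: cuDNN benchmarks up front and picks well-performing algorithms (e.g. an efficient convolution algorithm)
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current device:", DEVICE)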

    train_dataset=loaddataset.PreDataset(root='dataset', train=True, transform=config.train_transform, download=True)
    train_data=torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size, shuffle=True, num_workers=16 , drop_last=True)

    model =net.SimCLRStage1().to(DEVICE)
    lossLR=net.Loss().to(DEVICE)
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

    os.makedirs(config.save_path, exist_ok=True)
    for epoch in range(1,args.max_epoch+1):
        model.train()
        total_loss = 0
        for batch,(imgL,imgR,labels) in enumerate(train_data):
            imgL,imgR,labels=imgL.to(DEVICE),imgR.to(DEVICE),labels.to(DEVICE)

            _, pre_L=model(imgL)
            _, pre_R=model(imgR)

            loss=lossLR(pre_L,pre_R,args.batch_size)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("epoch", epoch, "batch", batch, "loss:", loss.detach().item())
            total_loss += loss.detach().item()

        print("epoch loss:",total_loss/len(train_dataset)*args.batch_size)

        with open(os.path.join(config.save_path, "stage1_loss.txt"), "a") as f:
            f.write(str(total_loss/len(train_dataset)*args.batch_size) + " ")

        if epoch % 5==0:
            torch.save(model.state_dict(), os.path.join(config.save_path, 'model_stage1_epoch' + str(epoch) + '.pth'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--batch_size', default=200, type=int, help='')
    parser.add_argument('--max_epoch', default=1000, type=int, help='')

    args = parser.parse_args()
    train(args)

4.2 Detailed code explanation

4.2.1 Set available GPUs

def train(args):
    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current device:", DEVICE)

This defines a function named train. It accepts one parameter, args, which holds the training parameters parsed from the command line.

Inside the function, first check whether the GPU is available and the use of GPU is set in the configuration file (config.use_gpu). If the conditions are met, the DEVICE device is set as an available GPU device. At the same time, enable CuDNN to automatically find the most suitable convolution algorithm for the current hardware to improve performance by setting torch.backends.cudnn.benchmark = True. If the conditions are not met, set the DEVICE device as the CPU device.
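train() builds cuda:<gpu_name> directly, so on a machine with a single GPU the default gpu_name=1 names a device that does not exist, and the error only surfaces when the model is first moved to it. The sketch below is one possible variant that falls back to cuda:0; pick_device is a hypothetical name, and falling back silently is a design choice (raising a clear error would be equally reasonable).

```python
import torch

def pick_device(use_gpu: bool, gpu_index: int) -> torch.device:
    """Hypothetical helper mirroring train()'s device logic, plus a GPU-index range check."""
    if use_gpu and torch.cuda.is_available():
        # config.gpu_name may point at a GPU this machine does not have
        index = gpu_index if gpu_index < torch.cuda.device_count() else 0
        torch.backends.cudnn.benchmark = True   # fixed 32x32 inputs suit cuDNN autotuning
        return torch.device(f"cuda:{index}")
    return torch.device("cpu")

print("current device:", pick_device(use_gpu=True, gpu_index=1))
```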

4.2.2 Loading the data set

train_dataset=loaddataset.PreDataset(root='dataset', train=True, transform=config.train_transform, download=True)
train_data=torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size, shuffle=True, num_workers=16 , drop_last=True)

A training data set object train_dataset is created here. Load the data set by calling the loaddataset.PreDataset class. The parameters passed in include the root directory of the data set, whether it is a training set train, the data preprocessing operation transform, etc.

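Then, torch.utils.data.DataLoader wraps the training dataset into an iterable data loader train_data. It sets the batch size (batch_size), whether to reshuffle the data every epoch (shuffle), the number of worker subprocesses used for loading (num_workers), and drop_last=True, which discards the final incomplete batch. drop_last matters here because net.Loss builds its mask from args.batch_size, so every batch must contain exactly batch_size samples.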
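The effect of drop_last, and why the loss depends on it, can be seen on a toy dataset of 10 samples (the sizes here are illustrative, not the article's settings):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10, dtype=torch.float32))    # 10 toy samples
keep = [batch[0].shape[0] for batch in DataLoader(data, batch_size=4)]
drop = [batch[0].shape[0] for batch in DataLoader(data, batch_size=4, drop_last=True)]
print(keep)   # [4, 4, 2]
print(drop)   # [4, 4]

# net.Loss sizes its mask with the configured batch_size, not the actual batch size
batch_size = 4
sim_matrix = torch.ones(2 * 2, 2 * 2)      # similarity matrix of a short final batch of 2
try:
    torch.ones_like(sim_matrix) - torch.eye(2 * batch_size)  # [4, 4] minus [8, 8]
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```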

4.2.3 Create training model loss function and optimizer

model = net.SimCLRStage1().to(DEVICE)
lossLR = net.Loss().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

An unsupervised SimCLR model, model, is created by instantiating the net.SimCLRStage1 class and moved to the previously selected device DEVICE.

In addition, the loss function lossLR is created by instantiating the net.Loss class and is also moved to DEVICE.

Finally, define the optimizer and use the Adam optimizer to optimize the parameters of the model. Pass the model parameters to the optimizer and set parameters such as learning rate lr and weight decay weight_decay.

4.2.4 Save training process files

os.makedirs(config.save_path, exist_ok=True)

This line of code is used to create a folder to save the model and the results during the training process.
config.save_path is the save path defined in the config module.

4.2.5 Training loop over epochs and batches

for epoch in range(1,args.max_epoch+1):
    model.train()
    total_loss = 0
    for batch,(imgL,imgR,labels) in enumerate(train_data):
        imgL,imgR,labels=imgL.to(DEVICE),imgR.to(DEVICE),labels.to(DEVICE)

        _, pre_L=model(imgL)
        _, pre_R=model(imgR)

        loss=lossLR(pre_L,pre_R,args.batch_size)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print("epoch", epoch, "batch", batch, "loss:", loss.detach().item())
        total_loss += loss.detach().item()

    print("epoch loss:",total_loss/len(train_dataset)*args.batch_size)
    # write the epoch loss to the log file stage1_loss.txt
    with open(os.path.join(config.save_path, "stage1_loss.txt"), "a") as f:
        f.write(str(total_loss/len(train_dataset)*args.batch_size) + " ")

    if epoch % 5==0:
        torch.save(model.state_dict(), os.path.join(config.save_path, 'model_stage1_epoch' + str(epoch) + '.pth'))

This is the main training loop. At the start of each training epoch, the model is set to training mode (model.train()). Then, enumerate(train_data) iterates over the batches in the training data loader.

In each batch, the data is moved to the device DEVICE. Both augmented views, imgL and imgR, are passed through the model; only the second return value, the normalized projection-head output, is kept as pre_L and pre_R. The loss is then computed by calling lossLR with pre_L, pre_R and the batch size args.batch_size. The labels are loaded but not used in this stage.

Next, perform the optimization steps. First, zero the optimizer's gradient buffer (optimizer.zero_grad()). Then, calculate the gradient of the loss value with respect to the model parameters (loss.backward()). Finally, the optimizer's step() method is called to update the model parameters.

After each batch, print out the current training epoch, batch, and loss values. The total loss value is accumulated to calculate the average loss value for each training epoch.

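After completing an epoch, the average loss for that epoch is printed and appended to the stage1_loss.txt file under the save path. The expression total_loss/len(train_dataset)*args.batch_size equals total_loss divided by the number of batches whenever the dataset size is a multiple of the batch size (50,000 / 200 = 250 batches with the defaults). Otherwise, because drop_last=True keeps only len(train_dataset) // args.batch_size batches, dividing by len(train_data) gives the exact per-batch mean.

If the current epoch is a multiple of 5, the model's state dictionary is saved to a file whose name contains the epoch number.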

4.2.6 Set command line parameters

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--batch_size', default=200, type=int, help='')
    parser.add_argument('--max_epoch', default=1000, type=int, help='')

    args = parser.parse_args()
    train(args)

In this code block, an argparse.ArgumentParser object is first created to parse command line arguments; its description string appears in the generated help text. Two command line arguments are then added via add_argument: --batch_size and --max_epoch. The default parameter sets each argument's default value, type sets its type, and help is an optional description of the argument (left empty here).

Next, parser.parse_args() parses the command line arguments and stores the result in args, so the values given on the command line can be accessed as args.batch_size and args.max_epoch.

Finally, call the train(args) function and pass the parsed parameters for training. This wraps the training process in an executable script, and you can run the training process by specifying parameters on the command line.
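parse_args also accepts an explicit list of strings, which is a convenient way to see how the defaults and overrides behave without launching training. A small sketch using the same two arguments:

```python
import argparse

parser = argparse.ArgumentParser(description='Train SimCLR')
parser.add_argument('--batch_size', default=200, type=int, help='')
parser.add_argument('--max_epoch', default=1000, type=int, help='')

# Equivalent to running: python trainstage1.py --batch_size 128
args = parser.parse_args(['--batch_size', '128'])
print(args.batch_size, args.max_epoch)   # 128 1000
```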

5. Supervised training stage: trainstage2.py

5.1 Code

# trainstage2.py
import torch,argparse,os
import net,config
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader


# train stage two
def train(args):

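    # Check whether a GPU is available: choose the device from the use_gpu setting in the config file and GPU availability, and configure CUDA acceleration if a GPU is used.
    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        # useful when the computation graph changes little between iterations: cuDNN benchmarks up front and picks well-performing algorithms (e.g. an efficient convolution algorithm)
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current device:", DEVICE)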

    # load dataset for train and eval
    train_dataset = CIFAR10(root='dataset', train=True, transform=config.train_transform, download=True)
    train_data = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16, pin_memory=True)
    eval_dataset = CIFAR10(root='dataset', train=False, transform=config.test_transform, download=True)
    eval_data = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True)

    model =net.SimCLRStage2(num_class=len(train_dataset.classes)).to(DEVICE)
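    # Load the pretrained parameters into the model: torch.load reads the parameter file, and model.load_state_dict copies the parameters into the model.
    # args.pre_model is the path of the pretrained model given by the command line argument --pre_model.
    model.load_state_dict(torch.load(args.pre_model, map_location='cpu'),strict=False)
    # loss function: cross-entropy loss
    loss_criterion = torch.nn.CrossEntropyLoss()
    # optimizer: only the classifier (fc) parameters are optimized
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)
    # create a folder for saving models and results
    os.makedirs(config.save_path, exist_ok=True)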

    for epoch in range(1,args.max_epoch+1):
        model.train()
        total_loss=0
        for batch, (data, target) in enumerate(train_data):
            data, target = data.to(DEVICE), target.to(DEVICE)
            pred = model(data)

            loss = loss_criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print("epoch",epoch,"loss:", total_loss / len(train_dataset)*args.batch_size)
        with open(os.path.join(config.save_path, "stage2_loss.txt"), "a") as f:
            f.write(str(total_loss / len(train_dataset)*args.batch_size) + " ")

        if epoch % 5==0:
            torch.save(model.state_dict(), os.path.join(config.save_path, 'model_stage2_epoch' + str(epoch) + '.pth'))

            model.eval()
            with torch.no_grad():
                print("batch", " " * 1, "top1 acc", " " * 1, "top5 acc")
                total_loss, total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0.0, 0
                for batch, (data, target) in enumerate(eval_data):
                    data, target = data.to(DEVICE), target.to(DEVICE)
                    pred = model(data)

                    total_num += data.size(0)
                    prediction = torch.argsort(pred, dim=-1, descending=True)
                    top1_acc = torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
                    top5_acc = torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
                    total_correct_1 += top1_acc
                    total_correct_5 += top5_acc

                    print("  {:02}  ".format(batch + 1), " {:02.3f}%  ".format(top1_acc / data.size(0) * 100),
                          "{:02.3f}%  ".format(top5_acc / data.size(0) * 100))

                print("all eval dataset:", "top1 acc: {:02.3f}%".format(total_correct_1 / total_num * 100),
                          "top5 acc:{:02.3f}%".format(total_correct_5 / total_num * 100))
                with open(os.path.join(config.save_path, "stage2_top1_acc.txt"), "a") as f:
                    f.write(str(total_correct_1 / total_num * 100) + " ")
                with open(os.path.join(config.save_path, "stage2_top5_acc.txt"), "a") as f:
                    f.write(str(total_correct_5 / total_num * 100) + " ")

# Check whether this script is run directly as the main program; if so, parse the command line arguments and call train to start training.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--batch_size', default=200, type=int, help='')
    parser.add_argument('--max_epoch', default=200, type=int, help='')
    parser.add_argument('--pre_model', default=config.pre_model, type=str, help='')

    args = parser.parse_args()
    train(args)

5.2 Detailed code explanation

5.2.1 Loading the data set

    train_dataset = CIFAR10(root='dataset', train=True, transform=config.train_transform, download=True)
    train_data = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16, pin_memory=True)
    eval_dataset = CIFAR10(root='dataset', train=False, transform=config.test_transform, download=True)
    eval_data = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True)

This code loads the training and test sets. CIFAR10 is torchvision's CIFAR-10 image classification dataset class. The root parameter sets the dataset directory; train=True loads the 50,000-image training set, and train=False loads the 10,000-image test set, which is used here for evaluation. The transform parameter specifies the preprocessing pipeline, and download=True downloads the dataset if it is not already present.

DataLoader wraps each dataset into an iterable data loader. batch_size sets the number of samples per batch, shuffle controls whether the data is reshuffled at every epoch, num_workers sets the number of worker subprocesses used for loading, and pin_memory=True places batches in page-locked (pinned) memory to speed up host-to-GPU transfers. Unlike stage one, drop_last is not needed here because CrossEntropyLoss works with any batch size.

5.2.2 Creating supervised modules

    model =net.SimCLRStage2(num_class=len(train_dataset.classes)).to(DEVICE)
    model.load_state_dict(torch.load(args.pre_model, map_location='cpu'),strict=False)
    loss_criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)

This code is used to create the model, load the parameters of the pretrained model, and define the loss function and optimizer.

net.SimCLRStage2 is a custom model class used for the second stage of training. Get the number of categories in the training set through len(train_dataset.classes) and pass it to the model as the output category number.

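torch.load loads the saved stage-one state dictionary from args.pre_model, the path given by the command line argument --pre_model. map_location='cpu' loads the tensors onto the CPU first, regardless of the device they were saved from; load_state_dict then copies them into the model's parameters on DEVICE. strict=False is needed because the keys do not match exactly: the checkpoint contains the projection head g, which SimCLRStage2 does not have, and it lacks the new classifier fc. Matching keys (the encoder f) are loaded, and mismatched ones are silently ignored.

torch.nn.CrossEntropyLoss is the cross-entropy loss function, used for multi-class classification.

torch.optim.Adam is the Adam optimizer. model.fc.parameters() restricts optimization to the linear classifier (the encoder is frozen), lr=1e-3 sets the learning rate to 0.001, and weight_decay=1e-6 sets the weight decay (L2 penalty) coefficient.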
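Because strict=False silently skips mismatched keys, it is worth checking what was actually loaded. load_state_dict returns a named tuple with missing_keys and unexpected_keys; the sketch below prints them for hypothetical miniature Stage1/Stage2 classes that only share the attribute names f, g and fc with the real models.

```python
import torch
import torch.nn as nn

class Stage1(nn.Module):
    """Hypothetical miniature of SimCLRStage1: encoder f + projection head g."""
    def __init__(self):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(4, 8))
        self.g = nn.Sequential(nn.Linear(8, 2))

class Stage2(nn.Module):
    """Hypothetical miniature of SimCLRStage2: same encoder layout + new classifier fc."""
    def __init__(self, num_class=3):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(4, 8))
        self.fc = nn.Linear(8, num_class)

stage1, stage2 = Stage1(), Stage2()
result = stage2.load_state_dict(stage1.state_dict(), strict=False)
print(result.missing_keys)     # ['fc.weight', 'fc.bias']: classifier stays randomly initialized
print(result.unexpected_keys)  # ['g.0.weight', 'g.0.bias']: projection head is discarded
print(torch.equal(stage1.f[0].weight, stage2.f[0].weight))  # True: encoder weights were copied
```

With the real checkpoint, seeing f.* keys in missing_keys would mean the encoder was not loaded at all, for example because the checkpoint came from a model wrapped in nn.DataParallel, whose keys carry a module. prefix.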

5.2.3 Main loop of training and validation

    for epoch in range(1,args.max_epoch+1):
        model.train()
        total_loss=0
        for batch, (data, target) in enumerate(train_data):
            data, target = data.to(DEVICE), target.to(DEVICE)
            pred = model(data)

            loss = loss_criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print("epoch",epoch,"loss:", total_loss / len(train_dataset)*args.batch_size)
        with open(os.path.join(config.save_path, "stage2_loss.txt"), "a") as f:
            f.write(str(total_loss / len(train_dataset)*args.batch_size) + " ")

        if epoch % 5==0:
            torch.save(model.state_dict(), os.path.join(config.save_path, 'model_stage2_epoch' + str(epoch) + '.pth'))

            model.eval()
            with torch.no_grad():
                print("batch", " " * 1, "top1 acc", " " * 1, "top5 acc")
                total_loss, total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0.0, 0
                for batch, (data, target) in enumerate(train_data):
                    data, target = data.to(DEVICE), target.to(DEVICE)
                    pred = model(data)

                    total_num += data.size(0)
                    prediction = torch.argsort(pred, dim=-1, descending=True)
                    top1_acc = torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
                    top5_acc = torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
                    total_correct_1 += top1_acc
                    total_correct_5 += top5_acc

                    print("  {:02}  ".format(batch + 1), " {:02.3f}%  ".format(top1_acc / data.size(0) * 100),
                          "{:02.3f}%  ".format(top5_acc / data.size(0) * 100))

                print("all eval dataset:", "top1 acc: {:02.3f}%".format(total_correct_1 / total_num * 100),
                          "top5 acc:{:02.3f}%".format(total_correct_5 / total_num * 100))
                with open(os.path.join(config.save_path, "stage2_top1_acc.txt"), "a") as f:
                    f.write(str(total_correct_1 / total_num * 100) + " ")
                with open(os.path.join(config.save_path, "stage2_top5_acc.txt"), "a") as f:
                    f.write(str(total_correct_5 / total_num * 100) + " ")

In each epoch, the model is set to training mode model.train(), and iterates over the training data set.

In each batch, the data and labels are moved to the target device, and the model computes the predictions pred. loss_criterion then computes the loss between the predictions and the true labels. optimizer.zero_grad() clears the gradients left over from the previous step, loss.backward() computes new gradients, and optimizer.step() updates the parameters.

At the end of each epoch, print the average loss and write the loss value to a file (stage2_loss.txt) for subsequent analysis and visualization.

If the current epoch is a multiple of 5, save the model parameters to the file (model_stage2_epoch{epoch}.pth) for subsequent use.

Every 5 epochs, right after the checkpoint is saved, the model is evaluated on the validation set:

1. The model is switched to evaluation mode with model.eval(), and gradient computation is disabled with torch.no_grad().
2. For each batch, the top-1 and top-5 hit counts are computed and accumulated, along with the number of samples.
3. After the whole validation set has been processed, the overall top-1 and top-5 accuracies are printed. They are also appended to stage2_top1_acc.txt and stage2_top5_acc.txt for later analysis and visualization.
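The same train/evaluate pattern can be run end to end on toy data. In this sketch, a linear classifier on random, linearly separable vectors replaces SimCLRStage2 and CIFAR10; these are assumptions made for illustration only. The epoch loss is the mean of the per-batch losses, and the held-out loader is evaluated every 5 epochs:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy, linearly separable data standing in for CIFAR10 (an assumption for illustration).
x = torch.randn(600, 20)
y = (x[:, 0] > 0).long()
train_data = DataLoader(TensorDataset(x[:500], y[:500]), batch_size=64, shuffle=True)
eval_data = DataLoader(TensorDataset(x[500:], y[500:]), batch_size=64, shuffle=False)

model = nn.Linear(20, 2)
loss_criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

epoch_losses, eval_accs = [], []
for epoch in range(1, 11):
    model.train()
    total_loss = 0.0
    for data, target in train_data:
        loss = loss_criterion(model(data), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean of the per-batch losses: divide by the number of batches.
    epoch_losses.append(total_loss / len(train_data))

    if epoch % 5 == 0:  # evaluate on the held-out loader every 5 epochs
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data, target in eval_data:
                correct += (model(data).argmax(dim=-1) == target).sum().item()
                total += target.size(0)
        eval_accs.append(correct / total)

print(len(epoch_losses), len(eval_accs))  # 10 2
```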

ps: What are top-k evaluation criteria?

In computer vision, a common evaluation metric is top-k accuracy, where k is the number of highest-ranked predictions taken into account:

Top-1 accuracy is the proportion of samples for which the single highest-scoring class matches the true label.
Top-5 accuracy is the proportion of samples whose true label appears among the five highest-scoring classes.

Concretely, in the evaluation code the model outputs a vector of scores for each sample, one score per class. Because these scores are fed directly to CrossEntropyLoss, they are unnormalized logits. Applying softmax would turn them into probabilities without changing their order.

The classes are sorted by score from high to low. A sample counts toward top-1 accuracy if the highest-ranked class equals the true label. It counts toward top-5 accuracy if the true label is among the five highest-ranked classes.

For example, if the highest-ranked class for an image matches its true label, that image counts as correct for top-1, and therefore also for top-5. If the true label is among the top five classes but not in first place, the image counts as correct for top-5 but not for top-1.

These accuracy metrics can help us understand the performance of the model in classification tasks, especially in multi-class classification problems. Top-1 accuracy is often considered the primary evaluation metric, while top-5 accuracy can provide a more relaxed evaluation, allowing the model to have some ambiguity or uncertainty in its prediction results.
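torch.topk offers a compact way to compute these counts. It returns only the k largest scores instead of sorting all of them. When there are no ties, it gives the same result as the argsort approach used in the code. The toy scores below use six classes instead of CIFAR10's ten and are made up for illustration:

```python
import torch

def topk_correct(logits, target, k):
    """Count samples whose true label is among the k highest-scoring classes."""
    topk_idx = logits.topk(k, dim=-1).indices               # shape (N, k)
    hits = (topk_idx == target.unsqueeze(-1)).any(dim=-1)   # shape (N,)
    return hits.sum().item()

# Three samples, six classes; true labels are 2, 0, 5.
logits = torch.tensor([
    [0.1, 0.3, 2.0, 0.2, 0.0, 0.1],  # class 2 ranked first -> top-1 and top-5 hit
    [0.5, 2.0, 1.0, 0.4, 0.3, 0.2],  # class 0 ranked third -> top-5 hit only
    [3.0, 2.5, 2.0, 1.5, 1.0, 0.5],  # class 5 ranked sixth -> miss both
])
target = torch.tensor([2, 0, 5])

print(topk_correct(logits, target, 1), topk_correct(logits, target, 5))  # 1 2
```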

6. Visualizing the training process

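showbyvisdom.py is an auxiliary script that plots the loss and accuracy values recorded during training. It uses the Visdom library to create interactive charts. Install Visdom with pip install visdom. The Visdom server must be running before the script is run: start it with python -m visdom.server, which serves at http://localhost:8097 by default.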

6.1 Code

# showbyvisdom.py
import numpy as np
import visdom


def show_loss(path, name, step=1):
    with open(path, "r") as f:
        data = f.read()
    data = data.split(" ")[:-1]  # values are space-separated; drop the empty string after the last space
    x = np.linspace(1, len(data), len(data)) * step  # 1, 2, ..., n, scaled by step
    y = []
    for i in range(len(data)):
        y.append(float(data[i]))

    vis = visdom.Visdom(env='loss')
    vis.line(X=x, Y=np.array(y), win=name, opts={'title': name, "xlabel": "epoch", "ylabel": name})


def compare2(path_1, path_2, title="xxx", legends=("a", "b"), xlabel="epoch", step=20):
    with open(path_1, "r") as f:
        data_1 = f.read()
    data_1 = data_1.split(" ")[:-1]

    with open(path_2, "r") as f:
        data_2 = f.read()
    data_2 = data_2.split(" ")[:-1]

    n = min(len(data_1), len(data_2))  # plot only the points present in both files
    x = np.linspace(1, n, n) * step
    y = []
    for i in range(n):
        y.append([float(data_1[i]), float(data_2[i])])  # one column per curve

    vis = visdom.Visdom(env='loss')
    vis.line(X=x, Y=np.array(y), win="compare",
             opts={"title": "compare " + title, "legend": list(legends), "xlabel": xlabel, "ylabel": title})


if __name__ == "__main__":
    show_loss("stage1_loss.txt", "loss1")
    show_loss("stage2_loss.txt", "loss2")
    # accuracies are written every 5 epochs, so step=5 labels the x-axis with real epoch numbers
    show_loss("stage2_top1_acc.txt", "acc1", step=5)
    show_loss("stage2_top5_acc.txt", "acc5", step=5)  # own window; reusing "acc1" would overwrite top-1

    # compare2("precision1.txt", "precision2.txt", title="precision", step=20)

6.2 Detailed code explanation

def show_loss(path, name, step=1):
    with open(path, "r") as f:
        data = f.read()
    data = data.split(" ")[:-1]
    x = np.linspace(1, len(data), len(data)) * step  # 1, 2, ..., n, scaled by step
    y = []
    for i in range(len(data)):
        y.append(float(data[i]))

    vis = visdom.Visdom(env='loss')
    vis.line(X=x, Y=np.array(y), win=name, opts={'title': name, "xlabel": "epoch", "ylabel": name})

The show_loss function plots how a recorded value changes over training. It takes three parameters:

path is the path of the file that contains the data.
name is the name of the Visdom window and is also used as the chart title.
step is the x-axis step size, with a default of 1.

Inside the function, open reads the whole file. split(" ") splits the content on spaces. [:-1] drops the last element, which is the empty string after the trailing space that the training loop writes. The result is a list of strings.

Next, use the np.linspace function to generate an abscissa array x with the same length as the data, and multiply it by the step size. Then create an empty list y.

Then use a loop to iterate through the data list, convert each string element to a floating point number, and add it to the y list.

Finally, a visdom.Visdom object is created in the 'loss' environment, and vis.line draws a line chart. Its parameters are:

X is the x-coordinate array x.
Y is the y-coordinate array y.
win is the window name.
opts sets the chart title and the axis labels.
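You can try the parsing part of show_loss without a Visdom server. This sketch uses a hypothetical helper, read_log. It writes and reads a temporary file in the same space-separated format as stage2_loss.txt. It uses np.arange so that the x values are exactly 1, 2, ..., n times step. The accuracy files are only written every 5 epochs, so step=5 labels their points with the real epoch numbers:

```python
import os
import tempfile
import numpy as np

def read_log(path, step=1):
    """Parse a file of space-separated numbers (as written by the training loop) into x, y arrays."""
    with open(path, "r") as f:
        values = [float(v) for v in f.read().split()]  # split() also drops the trailing empty field
    x = np.arange(1, len(values) + 1) * step  # 1, 2, ..., n, scaled by step
    return x, np.array(values)

# Write a small log in the same format as stage2_loss.txt: each value followed by a space.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "stage2_loss.txt")
    with open(path, "a") as f:
        for loss in (2.3, 1.7, 1.2):
            f.write(str(loss) + " ")
    x, y = read_log(path, step=5)

print(x, y)  # [ 5 10 15] [2.3 1.7 1.2]
```

The returned x and y can then be passed to vis.line(X=x, Y=y, ...).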

7. Validation set evaluation: eval.py

7.1 Code

# eval.py
import torch,argparse
from torchvision.datasets import CIFAR10
import net,config


def eval(args):
    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")

    eval_dataset=CIFAR10(root='dataset', train=False, transform=config.test_transform, download=True)
    eval_data=torch.utils.data.DataLoader(eval_dataset,batch_size=args.batch_size, shuffle=False, num_workers=16, )

    model=net.SimCLRStage2(num_class=len(eval_dataset.classes)).to(DEVICE)
    model.load_state_dict(torch.load(config.pre_model, map_location='cpu'), strict=False)

    total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0.0

    model.eval()
    with torch.no_grad():
        print("batch", " "*1, "top1 acc", " "*1,"top5 acc" )
        for batch, (data, target) in enumerate(eval_data):
            data, target = data.to(DEVICE) ,target.to(DEVICE)
            pred=model(data)

            total_num += data.size(0)
            prediction = torch.argsort(pred, dim=-1, descending=True)
            top1_acc = torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            top5_acc = torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            total_correct_1 += top1_acc
            total_correct_5 += top5_acc

            print("  {:02}  ".format(batch+1)," {:02.3f}%  ".format(top1_acc / data.size(0) * 100),"{:02.3f}%  ".format(top5_acc / data.size(0) * 100))

        print("all eval dataset:","top1 acc: {:02.3f}%".format(total_correct_1 / total_num * 100), "top5 acc:{:02.3f}%".format(total_correct_5 / total_num * 100))



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='test SimCLR')
    parser.add_argument('--batch_size', default=512, type=int, help='')

    args = parser.parse_args()
    eval(args)

7.2 Detailed code explanation

7.2.1 Loading the evaluation data set

eval_dataset=CIFAR10(root='dataset', train=False, transform=config.test_transform, download=True)
eval_data=torch.utils.data.DataLoader(eval_dataset,batch_size=args.batch_size, shuffle=False, num_workers=16, )

This code loads the CIFAR10 test set. The parameters are:

root is the dataset's root directory.
train=False selects the test split.
transform=config.test_transform applies the test-time preprocessing defined in the configuration file.
download=True downloads the dataset if it is not present.

torch.utils.data.DataLoader then wraps the test set in the loader eval_data, which yields it in batches. batch_size sets the batch size, shuffle=False keeps the samples in their original order, and num_workers sets the number of worker processes used for loading.

7.2.2 Creating a classifier model

model=net.SimCLRStage2(num_class=len(eval_dataset.classes)).to(DEVICE)
model.load_state_dict(torch.load(config.pre_model, map_location='cpu'), strict=False)

Create a net.SimCLRStage2 model object model for evaluation. SimCLRStage2 is a custom model class used for image classification. The num_class parameter is set to the number of classes of the evaluation dataset.

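Next, torch.load reads the trained parameters from config.pre_model, and load_state_dict loads them into the model.

map_location='cpu' always maps the stored tensors onto the CPU first, whether or not a GPU is present, so a checkpoint saved on a GPU can be opened anywhere. load_state_dict then copies the values into the model on DEVICE.

strict=False makes load_state_dict ignore keys that are present in only one of the checkpoint and the model instead of raising an error. Tensors whose shapes do not match still raise an error.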
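You can inspect what strict=False tolerates through the value returned by load_state_dict, which lists the missing and unexpected keys. In this sketch, Stage1Like and Stage2Like are hypothetical models that share an encoder f but have different heads (g and fc). They loosely mirror a stage-one checkpoint loaded into a stage-two classifier:

```python
import torch
import torch.nn as nn

# Hypothetical models: the "checkpoint" has encoder f + projection head g,
# the classifier has the same encoder f + a new head fc.
class Stage1Like(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = nn.Linear(8, 16)
        self.g = nn.Linear(16, 4)

class Stage2Like(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 10)

checkpoint = Stage1Like().state_dict()
model = Stage2Like()

# strict=False loads the matching keys (f.*) and reports the rest instead of raising.
result = model.load_state_dict(checkpoint, strict=False)
print(sorted(result.missing_keys))     # ['fc.bias', 'fc.weight']
print(sorted(result.unexpected_keys))  # ['g.bias', 'g.weight']
```

Printing these lists after loading is a cheap way to confirm that the encoder weights were actually loaded. strict=False otherwise fails silently when, for example, key prefixes differ.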

7.2.3 Verification

total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0.0

model.eval()
with torch.no_grad():
    print("batch", " "*1, "top1 acc", " "*1,"top5 acc" )
    for batch, (data, target) in enumerate(eval_data):
        data, target = data.to(DEVICE) ,target.to(DEVICE)
        pred=model(data)

        total_num += data.size(0)
        prediction = torch.argsort(pred, dim=-1, descending=True)
        top1_acc = torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
        top5_acc = torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
        total_correct_1 += top1_acc
        total_correct_5 += top5_acc

        print("  {:02}  ".format(batch+1)," {:02.3f}%  ".format(top1_acc / data.size(0) * 100),"{:02.3f}%  ".format(top5_acc / data.size(0) * 100))

    print("all eval dataset:","top1 acc: {:02.3f}%".format(total_correct_1 / total_num * 100), "top5 acc:{:02.3f}%".format(total_correct_5 / total_num * 100))

Initialize the variables total_correct_1, total_correct_5 and total_num to 0.

The model is then set to evaluation mode with model.eval(), and gradient computation is disabled with torch.no_grad().

In a loop, iterate over each batch in the evaluation data loader. Move the batch's input data and target labels to the device.

Through forward propagation of the model, the prediction result pred is obtained. Use the torch.argsort function to sort the prediction results to obtain a class index from high to low probability.

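For each batch, the Top-1 and Top-5 hits are computed in these steps:

1. The first k sorted indices are compared with the target label. target.unsqueeze(dim=-1) has shape (N, 1) and broadcasts against the (N, k) slice.
2. any(dim=-1) gives one Boolean per sample, which says whether the label is among the top k.
3. These Booleans are converted to floats and summed, and item() turns the result into a Python number: the count of correctly classified samples in the batch.
4. Dividing this count by data.size(0) gives the batch accuracy that is printed.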

Accumulate the number of correct predictions and the total number of samples in each batch. Print the current batch number, Top-1 accuracy and Top-5 accuracy in each batch. After the loop ends, print the Top-1 accuracy and Top-5 accuracy of the entire evaluation data set.
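You can trace the sorting and broadcasting steps on a tiny hand-made batch of 2 samples and 4 classes. With only 4 classes, top-2 stands in for top-5:

```python
import torch

# Scores for a batch of 2 samples over 4 classes, and their true labels.
pred = torch.tensor([[0.1, 0.7, 0.2, 0.0],
                     [0.4, 0.1, 0.3, 0.2]])
target = torch.tensor([2, 0])

prediction = torch.argsort(pred, dim=-1, descending=True)  # [[1, 2, 0, 3], [0, 2, 3, 1]]
# target.unsqueeze(-1) has shape (2, 1) and broadcasts against the (2, k) slice.
hits_top1 = (prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1)  # [False, True]
hits_top2 = (prediction[:, 0:2] == target.unsqueeze(dim=-1)).any(dim=-1)  # [True, True]

top1_count = torch.sum(hits_top1.float()).item()  # number of correct samples, not a ratio
top2_count = torch.sum(hits_top2.float()).item()
print(top1_count, top2_count, top1_count / pred.size(0))  # 1.0 2.0 0.5
```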

7.2.4 main

An argparse.ArgumentParser parses the command-line arguments. --batch_size is an optional argument with a default value of 512, and it sets the batch size used during evaluation.

Parses the command line arguments and passes the arguments to the eval function for evaluation.

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='test SimCLR')
    parser.add_argument('--batch_size', default=512, type=int, help='')

    args = parser.parse_args()
    eval(args)
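You can also give parse_args an explicit list of strings, which is a convenient way to check the defaults without a shell. From the command line, the equivalent call is python eval.py --batch_size 128:

```python
import argparse

parser = argparse.ArgumentParser(description='test SimCLR')
parser.add_argument('--batch_size', default=512, type=int, help='batch size used during evaluation')

# parse_args() normally reads sys.argv; passing a list simulates a command line.
default_args = parser.parse_args([])
custom_args = parser.parse_args(['--batch_size', '128'])
print(default_args.batch_size, custom_args.batch_size)  # 512 128
```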

8. Testing on a custom image

This script classifies a single image with the trained model.

8.1 Code

# test.py
import torch,argparse
import net,config
from torchvision.datasets import CIFAR10
import cv2


def show_CIFAR10(index):
    eval_dataset=CIFAR10(root='dataset', train=False, download=False)
    print(len(eval_dataset))
    print(eval_dataset.class_to_idx,eval_dataset.classes)
    img, target=eval_dataset[index][0], eval_dataset[index][1]

    import matplotlib.pyplot as plt
    plt.figure(str(target))
    plt.imshow(img)
    plt.show()


def test(args):
    classes={'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
             'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
    index2class=[x for x in classes.keys()]
    print("class:",index2class)

    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")

    transform = config.test_transform

    ori_img=cv2.imread(args.img_path,1)
    if ori_img is None:  # cv2.imread returns None instead of raising on a bad path
        raise FileNotFoundError("cannot read image: " + args.img_path)
    img=cv2.resize(ori_img,(32,32)) # very important: the model was trained on 32x32 CIFAR10 images
    img=cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # OpenCV loads BGR; the CIFAR10 training images are RGB

    img=transform(img).unsqueeze(dim=0).to(DEVICE)

    model=net.SimCLRStage2(num_class=10).to(DEVICE)
    model.load_state_dict(torch.load(args.pre_model, map_location='cpu'), strict=False)
    model.eval() # use BatchNorm running statistics, not those of this single image

    with torch.no_grad():
        pred = model(img)

    prediction = torch.argsort(pred, dim=-1, descending=True)

    label=index2class[prediction[:, 0:1].item()]
    cv2.putText(ori_img,"this is "+label,(30,30),cv2.FONT_HERSHEY_DUPLEX,1, (0,255,0), 1)
    cv2.imshow(label,ori_img)
    cv2.waitKey(0)


if __name__ == '__main__':
    # show_CIFAR10(2)

    parser = argparse.ArgumentParser(description='test SimCLR')
    parser.add_argument('--pre_model', default=config.pre_model, type=str, help='')
    parser.add_argument('--img_path', default="bird.jpg", type=str, help='')

    args = parser.parse_args()
    test(args)

8.2 Detailed code explanation

8.2.1 Create a test set and obtain images

def show_CIFAR10(index):
    eval_dataset=CIFAR10(root='dataset', train=False, download=False)
    print(eval_dataset.__len__())
    print(eval_dataset.class_to_idx,eval_dataset.classes)
    img, target=eval_dataset[index][0], eval_dataset[index][1]
    
    import matplotlib.pyplot as plt
    plt.figure(str(target))
    plt.imshow(img)
    plt.show()

The show_CIFAR10 function displays the image at a given index of the CIFAR10 test set. It first creates the dataset object eval_dataset with these parameters:

root specifies the dataset's root directory.
train=False selects the test set.
download=False means the dataset is not downloaded, so it must already be present under root. eval.py, for example, downloads it.

Then the image and its label at the given index are retrieved, and matplotlib.pyplot displays the image in a window named after the label index.

8.2.2 Preprocessing images

The test function is used to classify images.

def test(args):
    classes={'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
             'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
    index2class=[x for x in classes.keys()]
    print("class:",index2class)

    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")

    transform = config.test_transform

    ori_img=cv2.imread(args.img_path,1)
    img=cv2.resize(ori_img,(32,32)) # very important: the model was trained on 32x32 images
    img=cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # OpenCV loads BGR; the CIFAR10 training images are RGB

    img=transform(img).unsqueeze(dim=0).to(DEVICE)

First define a dictionary classes to map category names to category indexes. Then get the list of category names index2class by the key of the dictionary.

Get the transformation function transform for test data preprocessing in the configuration file.

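cv2.imread reads the image at args.img_path. The flag 1 (cv2.IMREAD_COLOR) loads it as a 3-channel color image. The original image ori_img is resized to 32x32 and stored in img. This step is very important because the model was trained on 32x32 CIFAR10 images.

cv2.imread also returns the pixels in BGR channel order, while the CIFAR10 training images are RGB. The image should therefore be converted with cv2.cvtColor(img, cv2.COLOR_BGR2RGB) before the transform is applied.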

The test-time transform is applied to the image. unsqueeze(dim=0) then adds a batch dimension at position 0, so the tensor has the shape (1, C, H, W) that the model expects. Finally, the result is moved to the device.
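The preprocessing chain can be written out explicitly with NumPy and PyTorch. This sketch rests on two assumptions. First, config.test_transform amounts to ToTensor followed by Normalize. Second, it uses the commonly quoted CIFAR10 mean and standard deviation below, and the values in config.py may differ. F.interpolate with bilinear mode stands in for cv2.resize, so the pixel values will be close to, but not identical with, the OpenCV result:

```python
import numpy as np
import torch
import torch.nn.functional as F

# Assumed CIFAR10 statistics; the real values live in config.test_transform and may differ.
MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
STD = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)

def preprocess(bgr_img):
    """bgr_img: HxWx3 uint8 array in OpenCV's BGR order -> 1x3x32x32 float tensor."""
    rgb = np.ascontiguousarray(bgr_img[:, :, ::-1])             # BGR -> RGB, like cv2.COLOR_BGR2RGB
    x = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0  # HWC uint8 -> CHW float in [0, 1]
    x = F.interpolate(x.unsqueeze(0), size=(32, 32), mode='bilinear', align_corners=False)
    return (x - MEAN) / STD                                      # already has the batch dimension

fake_photo = np.random.randint(0, 256, size=(240, 320, 3), dtype=np.uint8)
batch = preprocess(fake_photo)
print(batch.shape)  # torch.Size([1, 3, 32, 32])
```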

8.2.3 Creating a classification model

model=net.SimCLRStage2(num_class=10).to(DEVICE)
model.load_state_dict(torch.load(args.pre_model, map_location='cpu'), strict=False)
model.eval()  # use BatchNorm running statistics, not those of this single image

with torch.no_grad():  # inference only, no gradients needed
    pred = model(img)

prediction = torch.argsort(pred, dim=-1, descending=True)

label=index2class[prediction[:, 0:1].item()]
cv2.putText(ori_img,"this is "+label,(30,30),cv2.FONT_HERSHEY_DUPLEX,1, (0,255,0), 1)
cv2.imshow(label,ori_img)
cv2.waitKey(0)

Create a net.SimCLRStage2 model object model for classification. SimCLRStage2 is a custom model class used for image classification. The num_class parameter is set to 10, corresponding to the number of categories in the CIFAR10 data set.

To load the pre-trained parameters, torch.load reads the file args.pre_model with map_location='cpu', so the checkpoint can be loaded even on a machine without a GPU. strict=False makes load_state_dict ignore keys that are missing from, or unexpected in, the checkpoint instead of raising an error.

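The image img is passed through the model to obtain the prediction pred. The model should be in evaluation mode (model.eval()) for this forward pass. In training mode, BatchNorm would normalize with the statistics of this single image instead of the running statistics learned during training. Wrapping the call in torch.no_grad() also avoids building the autograd graph.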

Use the torch.argsort function to sort the prediction results pred from high to low according to probability to obtain the sorted index.

The index of the highest-scoring class is taken from the first column of the sorted indices. It is then mapped to its class name through the index2class list.

Use the cv2.putText function to add text labels to the original image and display the prediction results.

Finally, use the cv2.imshow function to display the image with the prediction results, and wait for the user to press any key to close the image window through cv2.waitKey(0).
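The model is trained with CrossEntropyLoss, so its outputs are presumably unnormalized scores (logits). Sorting them with argsort therefore gives the same ranking as sorting softmax probabilities. If you also want to display a confidence value, apply softmax first. Here is a sketch with made-up scores for one image:

```python
import torch

index2class = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Hypothetical logits for one image (shape 1x10), as model(img) would return.
pred = torch.tensor([[0.2, -1.0, 3.1, 0.5, 0.0, 0.3, -0.4, 0.1, -2.0, 0.0]])

probs = torch.softmax(pred, dim=-1)  # convert scores to probabilities summing to 1
conf, idx = probs.max(dim=-1)        # highest probability and its class index
label = index2class[idx.item()]
print(label, round(conf.item(), 3))  # bird, followed by its probability
```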

8.2.4 main

In the main program, an argparse.ArgumentParser parses the command-line arguments. It defines two arguments: --pre_model gives the path of the trained model parameters, and --img_path gives the path of the input image.

Parse the command line parameters and pass the parsing results to the test function for image classification.

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='test SimCLR')
    parser.add_argument('--pre_model', default=config.pre_model, type=str, help='')
    parser.add_argument('--img_path', default="bird.jpg", type=str, help='')

    args = parser.parse_args()
    test(args)


Origin blog.csdn.net/weixin_45662399/article/details/134939886