Functions and configuration used when building an image classification network (1)

1. parser = argparse.ArgumentParser()

argparse is the Python standard-library module for parsing command-line arguments and options. It supersedes the older optparse module.

Programs often need to parse command-line arguments. Here the purpose is to pass training parameters and options to the script from a terminal (the terminal on Ubuntu, or the Command Prompt on Windows).

Steps for usage

Using argparse usually comes down to the following four steps:

1:import argparse

2:parser = argparse.ArgumentParser()

3:parser.add_argument()

4:parser.parse_args()  

The four steps work as follows. First, import the module. Second, create a parser object. Third, register the command-line parameters and options you care about; each add_argument() call corresponds to one parameter or option. Finally, call parse_args() to parse the command line. Once parsing succeeds, the parsed values are available as attributes of the returned object.
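The four steps above can be sketched as follows. This is only a minimal illustration: the option names (--epochs, --batch-size, --lr, --data-path) and their defaults are hypothetical and are not taken from the original training script.

```python
import argparse  # step 1: import the module


def build_parser():
    """Build a parser with a few typical training options (hypothetical names)."""
    # Step 2: create the parser object.
    parser = argparse.ArgumentParser(description="Train an image classifier")
    # Step 3: one add_argument() call per parameter or option.
    parser.add_argument("--epochs", type=int, default=10, help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=8, help="samples per batch")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--data-path", type=str, default="./data", help="dataset root directory")
    return parser


# Step 4: parse. An explicit list is passed here only so the example runs anywhere;
# a real script would call parse_args() with no arguments to read the command line.
args = build_parser().parse_args(["--batch-size", "16", "--lr", "0.01"])

# Dashes in option names become underscores in attribute names.
print(args.epochs, args.batch_size, args.lr, args.data_path)
```

In a terminal, the equivalent call would look like `python train.py --batch-size 16 --lr 0.01`, where train.py is a stand-in name for the training script.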

2. Define the dataset class

import torch  
from torch.utils.data import Dataset  
from PIL import Image  
from torchvision import transforms  

class Mydataset(Dataset):
    """Custom dataset"""
    def __init__(self,images_path,images_class,transform=None):
        self.images_path = images_path                   # image file paths
        self.images_class = images_class                 # image class labels
        self.transform = transform                       # preprocessing applied to each image
  
  
    def __getitem__(self, index):
        img = Image.open(self.images_path[index])
        if img.mode != 'RGB' :
            # raise an exception if the image is not in RGB mode
            raise ValueError("image:{} isn't RGB mode.".format(self.images_path[index]))
        label = self.images_class[index]
        if self.transform is not None:
            img = self.transform(img)

        return img,label
  
    def __len__(self):  
        return len(self.images_path)  
  
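    @staticmethod                                    # takes no self, so it must be a static method
    def collatr_fn(batch):
        images,labels = tuple(zip(*batch))           # split (img, label) pairs into two tuples
        images = torch.stack(images,dim=0)           # stack images into one tensor along a new batch dimension
        labels = torch.as_tensor(labels)             # convert the labels to a tensor
        return images,labels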

The code above can be saved as a file named my_dataset.py so that it can be imported from other scripts.

The code in my_dataset.py defines a custom dataset class. The class stores the image paths, the image class labels, and the preprocessing transform. For a given index it opens the image, applies the transform, and returns the processed image together with its label. The collatr_fn method packs a list of such samples into a single batch.
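The first line of collatr_fn, `tuple(zip(*batch))`, is the least obvious part of the class. The sketch below isolates that step using plain strings as stand-ins for image tensors, so it runs without PyTorch. The helper name split_batch is hypothetical and only serves this illustration.

```python
def split_batch(batch):
    """Turn a list of (image, label) pairs into (all images, all labels)."""
    # zip(*batch) unpacks the pairs and regroups them position by position.
    images, labels = tuple(zip(*batch))
    return images, labels


# Stand-ins: in the real dataset each image would be a tensor, not a string.
batch = [("img0", 0), ("img1", 2), ("img2", 1)]
images, labels = split_batch(batch)
print(images)  # ('img0', 'img1', 'img2')
print(labels)  # (0, 2, 1)
```

In the real method, the tuple of image tensors is then combined with torch.stack and the labels are converted with torch.as_tensor.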

3. Define the image preprocessing and instantiate the datasets and data loaders

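import os
from torch.utils.data import DataLoader
from torchvision import transforms
from my_dataset import Mydataset    # the class defined in section 2, saved as my_dataset.py

images_size = 224
data_transform = {
    "train":transforms.Compose([transforms.RandomResizedCrop(images_size),                          # randomly crop a region, then resize the crop to images_size
                                transforms.RandomHorizontalFlip(),                                  # flip the image horizontally with the given probability (default 0.5)
                                transforms.ToTensor(),                                              # convert the PIL image to a Tensor
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),# normalize each channel with the ImageNet mean and standard deviation
    "val":transforms.Compose([transforms.Resize(int(images_size * 1.143)),                          # resize the shorter side to int(images_size * 1.143), keeping the aspect ratio
                              transforms.CenterCrop(images_size),                                   # crop an images_size x images_size region from the center
                              transforms.ToTensor(),                                                # convert the PIL image to a Tensor
                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} # normalize each channel with the ImageNet mean and standard deviation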
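# Instantiate the training dataset.
# The images_path and images_class values below are placeholders: in practice,
# pass a list of image file paths and a list of the matching labels.
train_dataset = Mydataset(images_path='',
                          images_class=1,
                          transform = data_transform["train"])
# Instantiate the validation dataset (same placeholders as above).
val_dataset = Mydataset(images_path='',
                        images_class=1,
                        transform = data_transform["val"])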
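batch_size = args.batch_size    # assumes the parser defines an option such as --batch-size
# Number of DataLoader workers: at most the CPU count, at most 8, and 0 when batch_size is 1
nw = min(os.cpu_count(),batch_size if batch_size > 1 else 0,8)
print("Using {} dataloader workers every process".format(nw))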
  
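train_loader = DataLoader(train_dataset,                         # dataset to load samples from
                          batch_size = batch_size,               # number of samples per batch
                          shuffle = True,                        # shuffle the training data every epoch
                          num_workers = nw,                      # number of worker subprocesses used to load data
                          collate_fn = train_dataset.collatr_fn, # packs the samples of a batch into tensors
                          pin_memory = True)                     # load data into pinned (page-locked) memory so it transfers faster to a CUDA GPU
val_loader = DataLoader(val_dataset,                             # dataset to load samples from
                        batch_size = batch_size,                 # number of samples per batch
                        shuffle = False,                         # do not shuffle the validation data
                        num_workers = nw,                        # number of worker subprocesses used to load data
                        collate_fn = val_dataset.collatr_fn,     # packs the samples of a batch into tensors
                        pin_memory = True)                       # load data into pinned (page-locked) memory so it transfers faster to a CUDA GPU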
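Two numbers in the listing above are easy to check by hand: the short-side length used by the validation Resize, and the worker count nw. The sketch below reproduces both calculations in plain Python. The helper names val_resize_target and pick_num_workers are hypothetical, and the fallback for a missing CPU count is an added assumption rather than part of the original code.

```python
import os


def val_resize_target(images_size, ratio=1.143):
    """Short-side length passed to Resize before the center crop."""
    return int(images_size * ratio)


def pick_num_workers(batch_size, cap=8):
    """Same rule as the listing: min(CPU count, batch size (0 if it is 1), cap)."""
    cpus = os.cpu_count() or 1  # assumption: fall back to 1 if the count is unknown
    return min(cpus, batch_size if batch_size > 1 else 0, cap)


print(val_resize_target(224))  # 256, so a 224 center crop is cut from a 256 short side
print(pick_num_workers(1))     # 0, meaning data is loaded in the main process
```

With a batch size larger than 1, the result depends on the machine and lies between 1 and 8.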

To be continued!

Origin blog.csdn.net/weixin_42715977/article/details/129924735