PyTorch study notes 1: make and load your own image data set




Preface

This note first shows how to use PyTorch to load an existing, publicly available data set, and then how to make your own image data set and read it in batches to train your own network.



One, download the data set

Use PyTorch's torchvision package to download the MNIST data set (or read it from the local ./data folder if it is already there) and load it

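import torchvision
from torchvision import transforms

# Download the training and test data
trainDataset = torchvision.datasets.MNIST(  # torchvision can download both the training and test splits
  root="./data",  # download the data and store it in the ./data folder
  train=True,  # True loads the training split; False loads the test split
  transform=transforms.ToTensor(),  # preprocessing lives in transforms; here each image is converted to a tensor
  download=True,  # download only if the data is not already under root
)

testDataset = torchvision.datasets.MNIST(
  root="./data",
  train=False,
  transform=transforms.ToTensor(),
  download=True,
)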

Two, load your own data set

1. Make a data set

Training a neural network requires input images and their ground-truth labels.
In a classification problem (cats, dogs, boats, cars, and so on), the classes can be represented by numbers, and a txt file can store the path of each input image together with its label number.
My task takes an image as input and uses another, processed image as its ground truth, so each line of my txt file holds the paths of an input image and its label image. Create a train folder under the project path for the training images, and create a txt file in the train folder (label.txt in the code below) that lists each training image and its label image.
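
As an illustration of this file layout, the sketch below writes a label.txt in which each line holds an input image name and its label image name separated by one space, then reads it back. The file names (img_0.png, gt_0.png, ...) and the temporary directory are hypothetical; only the one-pair-per-line format is taken from the loading code later in this note.

```python
import os
import tempfile

# Hypothetical pairs: (input image, ground-truth image), both stored in train/.
pairs = [("img_0.png", "gt_0.png"), ("img_1.png", "gt_1.png")]

data_dir = tempfile.mkdtemp()  # stand-in for the project path
train_dir = os.path.join(data_dir, "train")
os.makedirs(train_dir, exist_ok=True)

# Write one "<input image> <label image>" pair per line.
label_path = os.path.join(train_dir, "label.txt")
with open(label_path, "w") as f:
    for img_name, gt_name in pairs:
        f.write(f"{img_name} {gt_name}\n")

# Read it back: one sample per line, two space-separated fields.
with open(label_path) as f:
    rows = [line.split() for line in f if line.strip()]

print(rows)
```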

2. Load the data set

Dataset class

PyTorch reads images mainly through the Dataset class, the parent class that every data set loading class in PyTorch should inherit from. We read our own image data set by subclassing Dataset and overriding the following three methods:
The __init__ method reads the data files

The __getitem__ method supports index access

The __len__ method returns the size of the custom data set, which makes later traversal easier
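
Before the full class, the three methods can be seen in isolation in a toy example. This is only a sketch: ToyPairDataset and its random tensors are hypothetical and stand in for real image files.

```python
import torch
from torch.utils.data import Dataset


class ToyPairDataset(Dataset):
    """Hypothetical data set of random image/label pairs (for illustration only)."""

    def __init__(self, num_samples=5, image_size=(8, 8)):
        # __init__: prepare everything needed to locate or build the samples.
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 3, *image_size, generator=generator)
        self.labels = torch.rand(num_samples, 3, *image_size, generator=generator)

    def __len__(self):
        # __len__: called by len(dataset).
        return self.images.shape[0]

    def __getitem__(self, index):
        # __getitem__: called by dataset[index]; returns one sample.
        return {"image": self.images[index], "label": self.labels[index]}


dataset = ToyPairDataset()
print(len(dataset))
print(dataset[0]["image"].shape)
```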

import os

import cv2
import numpy as np
import torchvision
import torch.utils.data as Data
from PIL import Image


class OpticalSARDataset(Data.Dataset):
    """
      Custom data set: reads the sample list and loads image/label pairs.
    """

    def __init__(self, data_dir, part):
        # The images and label.txt live in data_dir/part
        assert part in ["train", "val"]
        self.image_dir = os.path.join(data_dir, part)
        self.img_names = []
        self.label_names = []

        # Each line of label.txt: "<input image name> <label image name>"
        with open(os.path.join(data_dir, part, "label.txt")) as f:
            for line in f:
                fields = line.split()  # also strips the trailing '\n'
                if len(fields) < 2:
                    continue  # skip blank or malformed lines
                self.img_names.append(fields[0])
                self.label_names.append(fields[1])
        self.samples_num = len(self.img_names)

        self.transform = torchvision.transforms.Compose([
            # Convert an H x W x C uint8 image to a FloatTensor in [0.0, 1.0], shape (C x H x W)
            torchvision.transforms.ToTensor()])

    def __len__(self):
        # Return the number of samples
        return self.samples_num

    def __getitem__(self, index):
        tp_img = Image.open(os.path.join(self.image_dir, self.img_names[index])
                            ).convert('RGB')
        tp_label = Image.open(os.path.join(self.image_dir, self.label_names[index])
                              ).convert('RGB')
        # PIL.Image.open returns RGB data; convert to BGR (OpenCV's channel order)
        # here, once, so the result can be shown with cv2.imshow without a second conversion.
        tp_img = cv2.cvtColor(np.asarray(tp_img), cv2.COLOR_RGB2BGR)
        tp_label = cv2.cvtColor(np.asarray(tp_label), cv2.COLOR_RGB2BGR)
        img = self.transform(tp_img)
        label = self.transform(tp_label)

        sample = {
            "label": label,  # shape: (3, H, W)
            "image": img  # shape: (3, H, W)
        }

        return sample

Define the data set

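# Create the data set object from the OpticalSARDataset class defined above
data_dir = "."  # assumed: the project path that contains the train folder
train_dataset = OpticalSARDataset(data_dir, 'train')  # training data set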

DataLoader class

The Dataset class described above reads the data set and indexes the samples.
That alone is not enough. In practice the data set is often very large, so we also need the following:
Batch reading: samples are read in batches of batch_size
Random reading: the data can be shuffled, which randomizes the order in which samples are drawn from the data set
Parallel loading: several worker processes on a multi-core processor speed up data loading
The DataLoader class provides all of this, so we do not need to write the code ourselves; we only need to pass it an instance of the Dataset class we designed.
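
The effect of batch_size and shuffle can be checked on a small in-memory data set. This is a minimal sketch under assumed values: the ten fake images, the labels 0 to 9, and batch_size=4 are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake RGB images of size 8x8 with integer labels 0..9.
images = torch.zeros(10, 3, 8, 8)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# num_workers=0 loads in the main process; a larger value uses worker processes.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_sizes = []
seen_labels = []
for batch_images, batch_labels in loader:
    batch_sizes.append(batch_images.shape[0])
    seen_labels.extend(batch_labels.tolist())

print(batch_sizes)          # the last batch is smaller: 10 is not a multiple of 4
print(seen_labels)          # shuffled order
print(sorted(seen_labels))  # every sample appears exactly once per epoch
```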

Instantiate the DataLoader

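import torch

# Read the data set object with a DataLoader, setting the batch size and the number of worker processes
batch_size = 4  # assumed value; choose one that fits your memory
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=0)
batch = next(iter(train_iter))  # fetch one batch (iterators have no .next() in Python 3)
print(batch["image"].shape, batch["label"].shape)
print(batch["image"][0].shape)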
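
To show how such dictionary batches would feed a training loop, here is a rough sketch. Everything in it is an assumption made for illustration: RandomPairs replaces the real image files with random tensors, and the one-layer network, the MSE loss, and the SGD settings are placeholders rather than the setup used for my task.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class RandomPairs(Dataset):
    """Hypothetical stand-in for an image/label data set that returns dict samples."""

    def __init__(self, n=8):
        g = torch.Generator().manual_seed(0)
        self.x = torch.rand(n, 3, 16, 16, generator=g)
        self.y = torch.rand(n, 3, 16, 16, generator=g)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {"image": self.x[i], "label": self.y[i]}


loader = DataLoader(RandomPairs(), batch_size=4, shuffle=True, num_workers=0)
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # placeholder network
criterion = nn.MSELoss()                           # the label is an image, so compare pixel-wise
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(2):
    for batch in loader:
        # The default collate function stacks each dict field into one batch tensor.
        inputs, targets = batch["image"], batch["label"]
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

print(len(losses))  # number of optimization steps taken
```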





Origin blog.csdn.net/qq_43173239/article/details/108948228