PyTorch image processing: reading a dataset with Dataset and ImageFolder

 1. Subclass the Dataset class and override its methods:

# Source code
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    # Reads one sample (path and label) for the given index. A subclass
    # therefore needs some "container" of paths and labels to index into.
    def __getitem__(self, index):
        raise NotImplementedError

    # Returns the length of the data.
    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        # NOTE(review): ConcatDataset is not defined in this excerpt —
        # presumably torch.utils.data.ConcatDataset; it must be in scope
        # for ``dataset_a + dataset_b`` to work.
        return ConcatDataset([self, other])

If you want to build your own image dataset for DataLoader to consume, you must first subclass the Dataset class and override its methods; the subclass is what defines where the data and labels are read from. The two key methods are __getitem__ (returns one sample and its label for a given index) and __len__ (returns the number of samples).

Once these two methods are implemented in your Dataset subclass, you can pass the dataset to a DataLoader during training to obtain the batches of data you want.

Example 1: Reading from a TXT file that lists image paths and labels

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
 
# Subclass the Dataset class
class MyDataset(Dataset):
    """Image dataset described by a text file of ``<image path> <label>`` lines."""

    def __init__(self, txt_path, transform=None, target_transform=None):
        """
        txt_path: path of a text file; every non-blank line holds an image
            path followed by an integer label, separated by whitespace.
        transform: optional callable applied to each loaded image, e.g.
            random cropping and conversion to a tensor.
        target_transform: optional callable applied to each label.
        """
        imgs = []  # (path, label) pairs, one per non-blank line
        # The context manager closes the file even if a line fails to parse.
        with open(txt_path, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    # Tolerate blank lines (typically a trailing newline).
                    continue
                imgs.append((words[0], int(words[1])))

        # Assigned after the loop so the attributes exist even when the
        # file contains no samples.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` sample stored at ``index``."""
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')  # force three RGB channels
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of samples listed in the text file."""
        return len(self.imgs)
 

Example 2: Reading from a file that holds both the data and the labels

# First, subclass the Dataset class
class DealDataset(Dataset):
    """
    Dataset backed by a comma-separated numeric file.

    Downloading and initialising the data can both be done here.
    """

    def __init__(self, data_path='../dataSet/diabetes.csv.gz'):
        """
        data_path: file readable by ``np.loadtxt`` with ',' as delimiter;
            the last column is used as the label, all other columns as
            the features. Defaults to the original hard-coded location.
        """
        # xy is the container: a file holding both the data and the labels.
        # ndmin=2 keeps a single-row file two-dimensional so the column
        # slicing below still works.
        xy = np.loadtxt(data_path, delimiter=',', dtype=np.float32, ndmin=2)

        self.x_data = torch.from_numpy(xy[:, 0:-1])
        # Indexing with [-1] (a list) keeps the label column 2-D.
        self.y_data = torch.from_numpy(xy[:, [-1]])

        # Number of rows; returned by __len__.
        self.len = xy.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.len
 

Example 3: There is no label file; the code builds the labels itself from the folder each image is stored in

class RMBDataset(Dataset):
    """Dataset for the RMB denomination classification task.

    No label file is used: the label of an image is looked up from the
    name of the sub-folder that contains it.
    """

    # Sub-folder name -> class index. Kept at class level so the static
    # helper get_img_info can reach it without an instance.
    label_name = {"1": 0, "100": 1}

    def __init__(self, data_dir, transform=None):
        """
        :param data_dir: str, path of the dataset directory
        :param transform: optional callable used to preprocess each image
        """
        # data_info stores every image path with its label; DataLoader
        # fetches samples from it by index through __getitem__.
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` sample stored at ``index``."""
        path_img, label = self.data_info[index]
        # Load as an RGB PIL image.
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            # Transforms run here, e.g. converting the image to a tensor.
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Return the number of images found under the data directory."""
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        """Collect ``(image path, label)`` pairs for every ``.jpg`` file.

        :param data_dir: str, directory walked with ``os.walk``
        :return: list of ``(path, int label)`` tuples
        :raises KeyError: if a folder holding ``.jpg`` files has a name that
            is not a key of ``label_name``
        """
        data_info = []
        for root, dirs, _ in os.walk(data_dir):
            # Sorting in place also fixes the order os.walk descends in,
            # so the sample order is reproducible across runs.
            dirs.sort()
            # Iterate over the class folders.
            for sub_dir in dirs:
                # Iterate over the images of this class.
                for img_name in sorted(os.listdir(os.path.join(root, sub_dir))):
                    if not img_name.endswith('.jpg'):
                        continue
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = RMBDataset.label_name[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info

2. Reading from folders with ImageFolder

# Preprocessing: convert to a tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5.
# torchvision.transforms is spelled out so the pipeline does not depend on
# (and then overwrite) a separately imported name called `transform`.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Read the data with torchvision.datasets.ImageFolder, pointing it at the
# train and test folders.
traindata = torchvision.datasets.ImageFolder('data/rmb_split/train/', transform=transform)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=4, shuffle=True, num_workers=1)


testset = torchvision.datasets.ImageFolder('data/rmb_split/test/', transform=transform)
# NOTE(review): the test loader also shuffles; kept as in the original.
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)
 

 

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=324472016&siteId=291194637