Pytorch模型训练之-- Pytorch数据集加载

Pytorch数据集加载

Dataset

如果弄明白了pytorchdataset类,你可以创建适应任意模型的数据集接口。

所谓数据集,无非就是一组{x:y}的集合,你只需要在这个类里说明“有一组=={x:y}==的集合”就可以了。

对于图像分类任务,图像+分类

对于目标检测任务,图像+bbox、分类

对于超分辨率任务,低分辨率图像+超分辨率图像

对于文本分类任务,文本+分类

可以通过.txt文件加载

/home/muzhan/projects/dataset/test/250_04.png _0
/home/muzhan/projects/dataset/test/250_05.png _7
/home/muzhan/projects/dataset/test/250_06.png _3
/home/muzhan/projects/dataset/test/250_07.png _2
/home/muzhan/projects/dataset/test/250_08.png _2
/home/muzhan/projects/dataset/test/250_09.png _3
/home/muzhan/projects/dataset/test/250_10.png _4
/home/muzhan/projects/dataset/test/250_11.png _0
/home/muzhan/projects/dataset/test/250_12.png _9

重新定义自己的dataset类

from torch.utils.data import Dataset
 
class MyDataSet(Dataset):
    def __init__(self, dataset_type, transform=None, update_dataset=False):
        """
        dataset_type: ['train', 'test']
        """
 
        dataset_path = '/home/muzhan/projects/dataset/'
 
        if update_dataset:
            make_txt_file(dataset_path)  # update datalist
 
        self.transform = transform
        self.sample_list = list()
        self.dataset_type = dataset_type
        f = open(dataset_path + self.dataset_type + '/datalist.txt')
        lines = f.readlines()
        for line in lines:
            self.sample_list.append(line.strip())
        f.close()
 
    def __getitem__(self, index):
        item = self.sample_list[index]
        # img = cv2.imread(item.split(' _')[0])
        img = Image.open(item.split(' _')[0])
        if self.transform is not None:
            img = self.transform(img)
        label = int(item.split(' _')[-1])
        return img, label
 
    def __len__(self):
        return len(self.sample_list)

Dataloader

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
dataset:定义的dataset类返回的结果。

batchsize:每个bacth要加载的样本数,默认为1。

shuffle:在每个epoch中对整个数据集data进行shuffle重排,默认为False。

sample:定义从数据集中加载数据所采用的策略,如果指定的话,shuffle必须为False;batch_sample类似,表示一次返回一个batch的index。

num_workers:表示开启多少个线程数去加载你的数据,默认为0,代表只使用主进程。

collate_fn:表示合并样本列表以形成小批量的Tensor对象。

pin_memory:表示要将load进来的数据是否要拷贝到pin_memory区中,其表示生成的Tensor数据是属于内存中的锁页内存区,这样将Tensor数据转义到GPU中速度就会快一些,默认为False。

drop_last:当你的整个数据长度不能够整除你的batchsize,选择是否要丢弃最后一个不完整的batch,默认为False。

enumerate()

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

遍历数据和标签

>>> import torch
>>> batch_data = torch.randn(10)
>>> batch_data
# tensor([-1.4227,  0.4803, -0.1308, -0.9972, -1.2646, -0.7575, -0.6185,  0.3919,
        -0.9820, -0.1905])
>>> labels = torch.linspace(1,10,10)
>>> labels
# tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
>>> for i, (data, labels) in enumerate(zip(batch_data, labels)):
        print("{}: {},{}".format(i, data, labels))
# 0: -1.4227263927459717,1.0
1: 0.48032230138778687,2.0
2: -0.13082626461982727,3.0
3: -0.9972370266914368,4.0
4: -1.2645894289016724,5.0
5: -0.7574924230575562,6.0
6: -0.6185144782066345,7.0
7: 0.39187055826187134,8.0
8: -0.9819689989089966,9.0
9: -0.19045710563659668,10.0

一般图片来说输入是
B x C x H x W 分别是 批量, 通道,高,宽

输出是

B x num_classes

训练过程中可用VISDOM进行可视化

注意充github源码进行安装不然可能失败

猜你喜欢

转载自blog.csdn.net/ahelloyou/article/details/114830706