【pytorch】自定义读取数据集,使用txt文本

使用txt文件记录数据路径、在需要时再读入图像,可以减少内存占用。有时自定义数据集的加载方式是非常必要的。下面的代码针对的是带有标签(label)的有监督图像数据,这里的标签本身也是一张图像(成对的 iPhone/Canon 照片)。先看代码:

import os

import numpy as np
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset


def default_loader(path):
    """Load the image at ``path`` and return it as an RGB PIL image.

    The file is opened explicitly inside a ``with`` block. ``convert`` decodes
    the pixels into a new image before the block exits, so the returned image
    does not depend on the file handle. The handle is then closed
    deterministically instead of being left to the garbage collector.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class MyDataSet(Dataset):
    """Dataset of (input image, target image) pairs listed in a text file.

    Each non-blank line of ``txt`` holds two whitespace-separated paths. The
    input image comes first and its target image second, which is the layout
    written by ``TxtData``. Because lines are split on whitespace, the paths
    themselves must not contain spaces. Images are only read from disk in
    ``__getitem__``, so only the path lists are kept in memory.

    Args:
        txt: path of the list file.
        transform: optional callable applied to the input image. It is also
            applied to the target image when ``target_transform`` is not given.
        target_transform: optional callable applied to the target image
            instead of ``transform``.
        loader: callable mapping a path to an image; defaults to
            ``default_loader``.
    """

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        super(MyDataSet, self).__init__()
        imgs = []
        target = []
        # ``with`` guarantees the list file is closed even if parsing fails.
        with open(txt, 'r') as fh:
            for line in fh:
                # split() with no argument already drops the trailing newline
                # and any surrounding whitespace.
                words = line.split()
                if not words:
                    # Skip blank lines (e.g. trailing empty lines) instead of
                    # failing with IndexError.
                    continue
                imgs.append(words[0])
                target.append(words[1])
        self.imgs = imgs
        self.target = target
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``index``-th pair as ``(img, target)``.

        Both images are loaded with ``self.loader`` and converted to NumPy
        arrays before any transform is applied.
        """
        img = np.asarray(self.loader(self.imgs[index]))
        target = np.asarray(self.loader(self.target[index]))

        if self.transform is not None:
            img = self.transform(img)
        # Use ``target_transform`` when one was provided. Otherwise keep
        # applying ``transform`` to the target so paired images get the same
        # preprocessing as the input.
        # NOTE(review): img and target go through separate transform calls, so
        # a random transform (crop/flip) may not be applied identically to
        # both images of a pair. Confirm this is acceptable.
        target_tf = self.target_transform if self.target_transform is not None else self.transform
        if target_tf is not None:
            target = target_tf(target)
        return img, target

    def __len__(self):
        """Return the number of pairs read from the list file."""
        return len(self.imgs)


def TxtData(canon_path, iphone_path, txt_path, txt_name='DPED.txt'):
    """Write a list file pairing iPhone images with their Canon counterparts.

    Each output line is ``<iphone_path>/<i>.jpg <canon_path>/<i>.jpg``. This
    is the format ``MyDataSet`` reads: input path first, target path second.
    Images are assumed to be named ``0.jpg``, ``1.jpg``, ... in both folders.
    The directory listings are only used to count the images.

    Args:
        canon_path: folder holding the Canon (target) images.
        iphone_path: folder holding the iPhone (input) images.
        txt_path: folder the list file is written into.
        txt_name: file name of the list file; defaults to ``'DPED.txt'``.

    Returns:
        The full path of the written list file.
    """
    list_file = os.path.join(txt_path, txt_name)
    n_canon = len(os.listdir(canon_path))
    n_iphone = len(os.listdir(iphone_path))

    if n_canon != n_iphone:
        print('the number of training data varies between the two folders')

    # Only emit indices that can exist in both folders. Using the Canon count
    # alone would write entries for iPhone images that are not there.
    n_pairs = min(n_canon, n_iphone)

    # ``with`` makes sure the file is flushed and closed before returning.
    with open(list_file, 'w') as fh:
        fh.writelines(
            os.path.join(iphone_path, str(i) + '.jpg') + ' '
            + os.path.join(canon_path, str(i) + '.jpg') + '\n'
            for i in range(n_pairs)
        )
    return list_file


if __name__ == '__main__':
    # Folder that receives the generated pair-list file.
    out_dir = './data'
    # Source folders holding the paired training images.
    iphone_dir = './data/iphone/training_data/iphone'
    canon_dir = './data/iphone/training_data/canon'
    TxtData(canon_path=canon_dir, iphone_path=iphone_dir, txt_path=out_dir)

上面代码主要分两部分:

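1:成对的图像分别存放在两个不同的文件夹中,文件路径可自行修改。先运行代码会生成一个 .txt 文件,其中包含了所有图像对的文件路径:每行依次为输入图像(iphone)和目标图像(canon)的路径,以空格分隔。txt 文件内容形如:./data/iphone/training_data/iphone/0.jpg ./data/iphone/training_data/canon/0.jpg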

2:定义自己的数据集加载子类。每当访问数据,即对数据集进行索引(dataset[i])时,会自动调用 __getitem__ 函数读取并返回对应的图像对。图像只在训练阶段被用到时才读入,因此不必一次性把所有图像加载到内存中。

猜你喜欢

转载自blog.csdn.net/lyl771857509/article/details/84642560
今日推荐