通过txt文本文件记录图像路径、在使用时再读入数据,可以减少内存占用。有时候自定义数据集的加载方式是非常必要的。下面的代码针对的是图像数据,并且是带有label的有监督图像。先看代码:
import numpy as np
import os
import torch.nn as nn
from PIL import Image
def default_loader(path):
    """Open the image stored at ``path`` and return it converted to RGB mode."""
    opened = Image.open(path)
    rgb_image = opened.convert('RGB')
    return rgb_image
class MyDataSet(nn.Module):
    """Paired-image dataset driven by a text index file.

    Each non-blank line of the index file holds two whitespace-separated
    paths: the input image first, the target image second. Only the two
    path lists are kept on the instance; the images themselves are loaded
    on demand in ``__getitem__``.

    NOTE(review): the class derives from ``nn.Module`` although it only
    implements the ``__getitem__``/``__len__`` protocol;
    ``torch.utils.data.Dataset`` would be the conventional base. Confirm
    with callers before changing the base class.

    Args:
        txt: Path of the index file.
        transform: Optional callable applied to the input image array.
        target_transform: Optional callable applied to the target image
            array. When omitted, ``transform`` is applied to the target
            as well.
        loader: Callable that takes a path and returns the loaded image.

    Raises:
        IndexError: If a non-blank line contains fewer than two fields.
    """

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        super(MyDataSet, self).__init__()
        imgs = []
        target = []
        # The context manager closes the index file even when a line
        # fails to parse.
        with open(txt, 'r') as fh:
            for line in fh:
                # split() already discards surrounding whitespace and the
                # trailing newline.
                words = line.split()
                if not words:
                    # Tolerate blank lines instead of failing on words[1].
                    continue
                imgs.append(words[0])
                target.append(words[1])
        self.imgs = imgs
        self.target = target
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair stored at ``index``.

        Both images are loaded with ``self.loader``, converted with
        ``np.asarray`` and then passed through their transforms.
        """
        img = self.loader(self.imgs[index])
        target = self.loader(self.target[index])
        img = np.asarray(img)
        target = np.asarray(target)
        if self.transform is not None:
            img = self.transform(img)
        # Honour target_transform when it was supplied; otherwise fall back
        # to the shared transform so both images are processed alike.
        target_transform = self.target_transform
        if target_transform is None:
            target_transform = self.transform
        if target_transform is not None:
            target = target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of image pairs listed in the index file."""
        return len(self.imgs)
def TxtData(canon_path, iphone_path, txt_path):
    """Write the ``DPED.txt`` index file that pairs iphone and canon images.

    One line is written for every ``i`` in ``range(n)``, where ``n`` is the
    number of entries in ``canon_path``. Each line has the form
    ``<iphone_path>/<i>.jpg <canon_path>/<i>.jpg``. The directory listings
    are only used for their lengths; file names are generated from the
    index, not read from the listing.

    Args:
        canon_path: Directory holding the canon images.
        iphone_path: Directory holding the iphone images.
        txt_path: Directory in which ``DPED.txt`` is created.
    """
    index_file = os.path.join(txt_path, 'DPED.txt')
    canon = os.listdir(canon_path)
    iphone = os.listdir(iphone_path)
    # The context manager guarantees the file is flushed and closed; the
    # handle was previously left open.
    with open(index_file, 'w') as out:
        if len(canon) != len(iphone):
            # Only a warning: the pairs are still written for len(canon).
            print('the number of traing data varies between the two files')
        for i in range(len(canon)):
            name = str(i) + '.jpg'
            out.write(
                os.path.join(iphone_path, name) + ' '
                + os.path.join(canon_path, name) + '\n'
            )
if __name__ == '__main__':
    # Script entry point: build the DPED.txt index under ./data from the
    # canon / iphone training image folders.
    TxtData(
        canon_path='./data/iphone/training_data/canon',
        iphone_path='./data/iphone/training_data/iphone',
        txt_path='./data',
    )
上面代码主要分两部分:
1:图像对分别存放在两个不同的文件夹中,文件路径可自行更改。先运行代码会生成一个.txt文件,里面包括了所有图像的文件路径。txt文件内容如下:
2:定义自己的加载数据的子类。每当使用数据时,也就是对数据集进行索引[ ]时,会自动调用 __getitem__ 函数。这部分只在训练阶段才会执行。