PyTorch reads images mainly through the Dataset class, so let's first take a brief look at the Dataset class. Dataset is the base class of all datasets, and every custom dataset must inherit from it, much like a virtual base class in C++.
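A minimal sketch of that protocol is shown below. It uses a plain Python stand-in class rather than importing torch, since torch.utils.data.Dataset relies on exactly these two dunder methods; the sample data is invented for illustration:

```python
class ListDataset:
    """Minimal Dataset-style class: a subclass of torch.utils.data.Dataset
    follows the same protocol of __getitem__ plus __len__."""

    def __init__(self, samples):
        # samples: a list of (data, label) pairs
        self.samples = samples

    def __getitem__(self, index):
        # receives an index, returns one sample (data, label)
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

ds = ListDataset([("img_0.png", 0), ("img_1.png", 1)])
print(len(ds))  # 2
print(ds[1])    # ('img_1.png', 1)
```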
Here we focus on the __getitem__ function. __getitem__ receives an index and returns the image data and its label. This index usually refers to a position in a list, and each element of that list holds the path and label information of one image.
So how do we build this list? The usual method is to store the image paths and label information in a txt file, and then read them back from that txt.
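That round trip can be sketched in a few lines; the paths and labels below are made up for illustration:

```python
import os
import tempfile

# Write "path label" lines to a txt file, then parse them back into a list.
tmp = tempfile.mkdtemp()
txt_path = os.path.join(tmp, "train.txt")

with open(txt_path, "w") as f:
    f.write("Data/train/0/0_1.png 0\n")
    f.write("Data/train/1/1_3.png 1\n")

samples = []
with open(txt_path) as f:
    for line in f:
        path, label = line.rstrip().split()
        samples.append((path, int(label)))

print(samples)  # [('Data/train/0/0_1.png', 0), ('Data/train/1/1_3.png', 1)]
```

Each element of `samples` is one (path, label) pair, which is exactly what __getitem__ will index into.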
The basic process of reading your own data is therefore:
1. Make a txt file that stores the path and label information of each image
2. Convert this information into a list, where each element of the list corresponds to one sample
3. Through the __getitem__ function, read the data and label and return them
Therefore, to let PyTorch read your own dataset, only two steps are required:
1. Make an index of the image data
2. Build a Dataset subclass
1. Code to generate the txt file
import os

'''
Generate the corresponding txt files for the dataset
'''
base_dir = "E:/pytorch_learning"  # change to the absolute path of the directory containing Data

train_txt_path = os.path.join(base_dir, "Data", "train.txt")
train_dir = os.path.join(base_dir, "Data", "train")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")
valid_dir = os.path.join(base_dir, "Data", "valid")
print(train_txt_path)
print(train_dir)
print(valid_txt_path)
print(valid_dir)

def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # get the names of the folders under train
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)  # absolute path of each class folder
            img_list = os.listdir(i_dir)  # all files in the class folder
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):  # skip files that are not png
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()

gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)
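To check the script without the real dataset, you can point the same gen_txt logic at a throwaway directory tree whose file names follow the label_index.png convention; the tree below is invented for the test, and empty files are enough because gen_txt only lists names and never opens the images:

```python
import os
import tempfile

def gen_txt(txt_path, img_dir):
    # same logic as the gen_txt above
    with open(txt_path, "w") as f:
        for root, s_dirs, _ in os.walk(img_dir, topdown=True):
            for sub_dir in s_dirs:
                i_dir = os.path.join(root, sub_dir)
                for name in os.listdir(i_dir):
                    if not name.endswith("png"):  # skip files that are not png
                        continue
                    label = name.split("_")[0]
                    f.write(os.path.join(i_dir, name) + " " + label + "\n")

# Build a throwaway tree: train/0/0_1.png and train/1/1_1.png
base = tempfile.mkdtemp()
for cls in ("0", "1"):
    os.makedirs(os.path.join(base, "train", cls))
    open(os.path.join(base, "train", cls, cls + "_1.png"), "w").close()

txt_path = os.path.join(base, "train.txt")
gen_txt(txt_path, os.path.join(base, "train"))

lines = sorted(open(txt_path).read().splitlines())
print(lines)  # two "path label" lines, one per image
```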
2. Result
3. Dataset class code
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()  # rstrip returns a copy of the string with the given trailing characters removed; with no argument it strips trailing whitespace
            words = line.split()  # split on whitespace by default
            imgs.append((words[0], int(words[1])))
        self.imgs = imgs  # the key point is building this list; DataLoader supplies an index, and __getitem__ uses it to read the image data
        # transform is a Compose object holding a list of image-processing operations:
        # mean subtraction, division by the standard deviation, random crop, rotation,
        # flip, affine transform, and so on. After an image is read in, it goes through
        # this processing (data augmentation) and becomes the input to the model.
        # Note that PyTorch's data augmentation processes the original image in place;
        # it does not generate an extra copy but "overwrites" the original.
        self.target_transform = target_transform
        self.transform = transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        # read the image
        img = Image.open(fn).convert('RGB')  # pixel values are 0~255; transforms.ToTensor divides by 255, mapping them to 0~1
        if self.transform is not None:
            img = self.transform(img)  # apply the transform here: convert to tensor, etc.
        return img, label

    def __len__(self):
        return len(self.imgs)
4. DataLoader
Once MyDataset is built, the rest of the work is handed over to DataLoader. DataLoader triggers the __getitem__ function in MyDataset to read the data and label of one image, then stitches the results into a batch and returns it as the real input of the model.
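Conceptually, what DataLoader does can be sketched without torch: it draws indices, calls __getitem__ for each one, and collates the results into a batch. The real DataLoader additionally handles shuffling, worker processes, and stacking the images into a single tensor; the toy dataset below is invented for illustration:

```python
class ToyDataset:
    # stands in for MyDataset: __getitem__ returns a (data, label) pair
    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

def simple_loader(dataset, batch_size):
    # yield batches as (list_of_data, list_of_labels); the real DataLoader
    # would also stack the image tensors along a new batch dimension
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        batch = [dataset[i] for i in range(start, stop)]  # one __getitem__ call per index
        data, labels = zip(*batch)
        yield list(data), list(labels)

ds = ToyDataset([("img_%d.png" % i, i % 2) for i in range(5)])
batches = list(simple_loader(ds, batch_size=2))
print(len(batches))  # 3 batches: sizes 2, 2, 1
print(batches[0])    # (['img_0.png', 'img_1.png'], [0, 1])
```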