PyTorch lee imágenes principalmente a través de la clase Dataset
, así que primero entendamos brevemente la
clase
Dataset .
conjunto de datos
La clase existe como la clase base de todos
los conjuntos de datos
, y todos los
conjuntos de datos
deben heredarla, de forma similar a
la base virtual en
C++
amable.
Aquí nos enfocamos en la
función
getitem ,
getitem recibe un índice y luego devuelve los datos de la imagen y las etiquetas, esto
índice
generalmente se refiere al
índice de una
lista
, y cada elemento de esta lista contiene la ruta y la información de la etiqueta de los datos de la imagen
aliento.
Sin embargo, cómo hacer esta
lista
, el método habitual es almacenar la ruta y etiquetar la información de la imagen en un
txt
, y luego
lea de ese
txt .
Entonces el proceso básico de leer sus propios datos es:
1. Cree
un txt
que almacene la ruta y la información de la etiqueta de la imagen
2.
Convierta esta información en
una lista
, y
cada elemento de la
lista corresponde a una muestra
3.
A través de la función
getitem
, lea los datos y las etiquetas y devuelva los datos y las etiquetas
Por lo tanto, para permitir que
PyTorch
lea su propio conjunto de datos, solo se requieren dos pasos:
1.
Haz un índice de datos de imágenes
2.
Cree una
subclase
de conjunto de datos
1. Generar código de Bloc de notas
import os
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径
'''
为数据集生成对应的txt文件
'''
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
train_dir = os.path.join(base_dir, "Data", "train")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")
valid_dir = os.path.join(base_dir, "Data", "valid")
print(train_txt_path)
print(train_dir)
print(valid_txt_path)
print(valid_dir)
def gen_txt(txt_path, img_dir):
f = open(txt_path, 'w')
for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 获取 train文件下各文件夹名称
for sub_dir in s_dirs:
i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
for i in range(len(img_list)):
if not img_list[i].endswith('png'): # 若不是png文件,跳过
continue
label = img_list[i].split('_')[0]
img_path = os.path.join(i_dir, img_list[i])
line = img_path + ' ' + label + '\n'
f.write(line)
f.close()
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)
2. Efecto
3. Código de clase del conjunto de datos
class MyDataset(Dataset):
def __init__(self, txt_path, transform=None, target_transform=None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip() #rstrip函数返回字符串副本,该副本是从字符串最右边删除了参数指定字符后的字符串,不带参数进去则是去除最右边的空格
words = line.split() #默认以空格为分隔符
imgs.append((words[0], int(words[1])))
self.imgs = imgs # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
# transform 是一个 Compose 类型,里边有一个 list,list 中就会定义了各种对图像进行处理的操作,
#可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作
#在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),
#最终变成输入模型的数据。这里就有一点需要注意,PyTorch 的数据增强是将原始图片进行了处理
#并不会生成新的一份图片,而是“覆盖”原图
self.target_transform = target_transform
self.transform = transform
def __getitem__(self, index):
fn, label = self.imgs[index]
#对图片进行读取
img = Image.open(fn).convert('RGB') # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.imgs)
4.carga de datos
Cuando se crea
Mydataset
, las operaciones restantes se transferirán a
DataLoder
. En
DataLoder
, se activará
La función getiterm
en
Mydataset lee los datos y las etiquetas de una imagen y los une en un lote para devolverlos, como
La entrada real del modelo.