[Pytorch] Función DataLoader

El conjunto de datos ( https://blog.csdn.net/TH_NUM/article/details/80877196 ) solo es responsable de la abstracción de datos, y solo una muestra devuelve una llamada a getitem . Como se mencionó anteriormente, cuando se entrena una red neuronal, es mejor operar en un lote de datos, y también es necesario realizar una aceleración aleatoria y paralela en los datos. En este sentido, PyTorch proporciona DataLoader para ayudarnos a lograr estas funciones.

Las funciones de DataLoader se definen de la siguiente manera:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
num_workers=0, collate_fn=default_collate, pin_memory=False, 
drop_last=False)

conjunto de datos : conjunto de datos cargados (objeto del conjunto de datos) 
tamaño_bits : tamaño de lote 
aleatorio : si se debe alterar el
muestreador de datos  : muestreo de muestra,
num_workers se describirá en detalle  más adelante : el número de procesos cargados usando multiproceso, 0 significa no usar multiproceso 
collate_fn : cómo Se empalman múltiples datos de muestra en un lote, por lo general, el método de empalme predeterminado se puede usar 
pin_memory : si se guardan los datos en el área de memoria del pin, los datos en la memoria del pin serán más rápidos de transferir a la GPU 
drop_last : el número de datos en el conjunto de datos puede no ser batch_size Múltiple entero de, drop_last es True descartará más de un lote de datos

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

#加上transforms
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    normalize
])

dataset=ImageFolder('data/dogcat_2/',transform=transform)

#dataloader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它 或者 or batch_datas, batch_labels in dataloader:
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)

dataiter = iter(dataloader)
imgs, labels = next(dataiter)
print(imgs.size()) # batch_size, channel, height, weight
#输出 torch.Size([3, 3, 224, 224])

En el procesamiento de datos, a veces no se puede leer una muestra, como una imagen dañada. En este momento  , aparecerá una excepción en la  función _getitem _, y la mejor solución en este momento es eliminar la muestra incorrecta. Si realmente no puede lidiar con esta situación, puede volver al objeto Ninguno y luego implementar un collate_fn personalizado en el Cargador de datos para filtrar los objetos vacíos. Pero tenga en cuenta que, en este caso, el número de lotes devueltos por el cargador de datos será menor que batch_size.

'''
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。
'''
from dataSet import *
import random
class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
    def __getitem__(self, index):
        try:
            # 调用父类的获取函数,即 DogCat.__getitem__(self, index)
            return super(NewDogCat,self).__getitem__(index)
        except:
            #对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:
            new_index = random.randint(0, len(self) - 1)
            return self[new_index]

from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
from torch.utils.data import DataLoader
def my_collate_fn(batch):
    '''
    batch中每个元素形如(data, label)
    '''
    # 过滤为None的数据
    batch = list(filter(lambda x:x[0] is not None, batch))
    if len(batch) == 0: return torch.Tensor()
    return default_collate(batch) # 用默认方式拼接过滤后的batch数据


transform=transforms.Compose([
    transforms.Resize(224), #缩放图片,保持长宽比不变,最短边的长为224像素,
    transforms.CenterCrop(224), #从中间切出 224*224的图片
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至[-1,1]
])


dataset = NewDogCat(root='data/dogcat_wrong/', transform=transform)

#print(dataSet[11])
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1,shuffle=True)
for batch_datas, batch_labels in dataloader:
    print(batch_datas.size(),batch_labels.size())

Dirección de ithub: https://github.com/WebLearning17/CommonTool

Referencia: https://github.com/chenyuntc/pytorch-book/blob/master/chapter5-%E5%B8%B8%E7%94%A8%E5%B7%A5%E5%85%B7/chapter5.ipynb

190 artículos originales publicados · elogiados 497 · 2.60 millones de visitas +

Supongo que te gusta

Origin blog.csdn.net/u013066730/article/details/104773020
Recomendado
Clasificación