PyTorch中torch.utils.data.DataLoader加载数据

torch.utils.data.DataLoader使用方法

       DataLoader是PyTorch中的一种数据类型,在PyTorch架构中训练或者验证模型经常要使用它,那么怎么生成以及使用这样的数据类型?
在这里插入图片描述


一、参数设置

torch.utils.data.DataLoader(
      dataset   			#数据加载
      batch_size = 1		#批处理样本大小
      shuffle = False		#是否在每一轮epoch打乱样本顺序
      sampler = None		#指定数据加载中使用的索引/键的序列
      batch_sampler = None	#和sampler类似
      num_workers = 0		#是否进行多进程加载数据设置
      collat​​e_fn = None		#是否合并样本列表以形成一小批Tensor
      pin_memory = False	#如果True,数据加载器会在返回之前将Tensors复制到CUDA固定内存
      drop_last = False		#True若数据集大小不能被batch_size整除,则删除最后一个不完整的批处理。
      timeout = 0			#如果为正,则为从工作人员收集批处理的超时值
      worker_init_fn = None )

       具体可参考官方文档

1、dataset:(数据类型 Dataset)
       输入的数据类型,也是最重要的参数,它表示要加载数据的数据集对象。

2、batch_size:(数据类型 int)
       批处理样本的大小,默认为1。

3、shuffle:(数据类型 bool)
       在每轮迭代训练时是否将数据洗牌。默认设置为False。设置为True则是在每一轮中,输入数据的顺序将被打乱,这是为了使数据更有独立性,训练的时候一般都设置为True,若输入数据是有序的,就不要设置成True了。

4、collate_fn:(数据类型 callable可调用对象)
       将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。

5、sampler:(数据类型 Sampler)
       采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。

6、num_workers:(数据类型 Int)
       子进程数量,默认是0。使用多少个子进程来加载数据。0 就是使用主进程来加载数据。注意:这个数字必须是大于等于0的,该值的设置应该量内存大小而为

7、pin_memory:(数据类型 bool)
       内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

扫描二维码关注公众号,回复: 11427359 查看本文章

8、drop_last:(数据类型 bool)
       丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。

9、timeout:(数据类型 numeric)
       超时值,默认为0。是用来设置数据读取的超时时间,超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。


二、实际应用

import torch
from torch.utils.data import Dataset, DataLoader

#---------------预处理-----------------
transform = transforms.Compose([
    	transforms.Resize((224, 224), 2),
    	transforms.ToTensor(),
    	transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
#--------------数据加载----------------
trainset = torchvision.datasets.CIFAR10(root='./data', 
										train=True, 
										download=False, 
										transform=transform)
# torch.utils.data.DataLoader
trainloader = DataLoader(dataset=trainset, 
							batch_size=32, 
							shuffle=True, 
							num_workers=0)  

for epoch in range(100):
	running_loss = 0.0
	batch_size = 32
	for i, data in enumerate(trainloader, 0):
    	inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)

猜你喜欢

转载自blog.csdn.net/qq_40520596/article/details/106981039