Pytorch教程[02]DataLoader与Dataset

机器学习模型训练步骤

模型训练步骤

在这里插入图片描述

一.DataLoader

torch.utils.data.DataLoader()

功能:构建可迭代的数据装载器
• dataset: Dataset类,决定数据从哪读取
及如何读取
• batchsize : 批大小
• num_works: 是否多进程读取数据
• shuffle: 每个epoch是否乱序
• drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

DataLoader( dataset,
			batch_size=1,
			shuffle=False,
			sampler=None,
			batch_sampler=None,
			num_workers=0,
			collate_fn=None,
			pin_memory=False,
			drop_last=False,
			timeout=0,
			worker_init_fn=None,
			multiprocessing_context=None)

[Epoch、Epoch、Batch]三者之间的关系

  • Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称之为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration

例:
样本总数:80, Batchsize:8
1 Epoch = 10 Iteration

样本总数:87, Batchsize:8
1 Epoch = 10 Iteration ? drop_last = True
1 Epoch = 11 Iteration ? drop_last = False

二、Dataset

torch.utils.data.Dataset()

功能:Dataset抽象类,所有自定义的
Dataset需要继承它,并且复写

__getitem__()
getitem #接收一个索引,返回一个样本
class Dataset(object):

	def __getitem__(self, index):
		raise NotImplementedError
		
	def __add__(self, other):
		return ConcatDataset([self, other])

在这里插入图片描述
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43694096/article/details/123426869
今日推荐