pytorch学习1:如何加载自己的训练数据


Pytorch中文文档已出(http://pytorch-cn.readthedocs.io/zh/latest/)。第一篇博客献给了pytorch,主要是为了整理自己的思路。

原来使用caffe,总是要编译,经历了无数的坑。当开始接触pytorch时,果断拔草caffe。

学习Pytorch最好有一些深度学习理论基础才更好开,废话不多说,进入主题。

1 先有个框框,再往里面填东西

当训练一个神经网络的时候,我们需要有数据,有模型,并且需要设置训练的参数。为了不乱,我们最好分别定义三个文件,分别是:数据准备和预处理traindataset.py+编写模型model.py+如何训练main.py(xx.py,xx自己可任意取名)。

今天我们只讲数据准备与预处理阶段:traindataset.py(怎样命名无所谓,as you like)。这个文件的作用是什么呢?

统一将图像(或矩阵)返回成torch能处理的[original_iamges.tensor,label.tensor]

我们先跳跃一下看中文介绍是如何导入数据:

 torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
我们一般关注DataLoader四个参数:
dataset, batch_size, shuffle, num_workers=0
batch_size是你批处理数目,shuffle是否每个epoch都打乱,workers是载入数据的线程数(请查看中文文档对每个参数的解释)

我们具体看看“dataset”——加载数据的数据集。

这个dataset应该是[original_iamges.tensor,label.tensor]之类的,我们定义的“traindataset.py”就是产生这个dataset的。


你只需在main.py 文件import就可调用!

from traindataset import *

2 定义一个py文件产生我们自己的dataset

这个py文件一定要1:能输入我自己的数据路径 2:还得预处理吧,比如的裁剪啊~
step 1:先导入你肯定需要的库路径
import torch.utils.data 
import torch
from tochvision import transforms
t orch.utils.data模块是子类化你的数据
transforms库对数据预处理
step 2:自定义dataset类(子类化你的数据)
class MyTrainData(torch.utils.data.Dataset)
这里继承了torch.utils.data.Dataset这个类,我们看看这个类在中文文档中介绍:
所有其他数据集都应该进行子类化。所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
当然还有个初始化__init__()
ok,我们脸谱化py文件,再往里面加东西(以下为基础框架):
 
  
#encoding:utf-8
import torch.utils.data as data
import torch
from torchvision import transforms

class trainmydatalala(torch.utils.data.Dataset) #子类化

  def __init__(self, root, transform=None, train=True): #第一步初始化

    self.root = root   
    self.train = train
  def __getitem__(self, idx): #第二步装载数据,返回[img,label]


      img = imread(img_path)
      img = torch.from_numpy(img).float()

      gt = imread(gt_path)    
      gt = torch.from_numpy(gt).float()

      return img, gt 

  def __len__(self):
    return len(self.imagenumber)

现在往框框里面填:(1)是否transform如裁剪、归一化、旋转等?(2)是否区分test和train?(3)如何做到一张一张对应读取图片?
Python读取图片需要scipy库,要想批处理读取,还需要os库;
以下贴出完整代码:
#encoding:utf-8
import torch.utils.data as data
import torch

from scipy.ndimage import imread
import os
import os.path
import glob

from torchvision import transforms

def make_dataset(root, train=True):

  dataset = []

  if train:
    dirgt = os.path.join(root, 'train_data/groundtruth') 
    dirimg = os.path.join(root, 'train_data/imgs')

    for fGT in glob.glob(os.path.join(dirgt, '*.jpg')):
    # for k in range(45)
      fName = os.path.basename(fGT)    
      fImg = 'train_ori'+fName[8:]
      dataset.append( [os.path.join(dirimg, fImg), os.path.join(dirgt, fName)] )

  return dataset

#自定義dataset的框架
class kaggle2016nerve1(data.Dataset):   #需要繼承data.Dataset

  def __init__(self, root, transform=None, train=True): #初始化文件路進或文件名
    self.train = train
    if self.train:
      self.train_set_path = make_dataset(root, train)

  def __getitem__(self, idx):
    if self.train:
      img_path, gt_path = self.train_set_path[idx]

      img = imread(img_path)
      img = np.atleast_3d(img).transpose(2, 0, 1).astype(np.float32)
      img = (img - img.min()) / (img.max() - img.min())
      img = torch.from_numpy(img).float()

      gt = imread(gt_path)
      gt = np.atleast_3d(gt).transpose(2, 0, 1)
      gt = gt / 255.0
      gt = torch.from_numpy(gt).float()

      return img, gt  

  def __len__(self):

    return len(self.train_set_path)
   
这里的py文件需要在最后main.py文件中调用,所以root我并没有赋值,我会在main,py中赋值。
这里我并没有用到“transform”进行预处理,如果你想用的话,在__getitem__()下面,return img,gt前重新赋值
img = transforms.ToTensor(img)以及gt = transforms.ToTensor(gt)
这需要注意的是,查看中文文档transforms库有哪些变换,如果有需要涉及参数的如CenterCrop(size),需要先实参化,如
crop = transforms.CenterCrop(10);再使用:img = crop(img)

猜你喜欢

转载自blog.csdn.net/woshicao11/article/details/78318156