Pytorch学习笔记(十 一)——pytorch自定义数据集

一、为什么要使用Datasets类

  Datasets是pytorch的一个类,pytorch自带多种数据集,如:MINIST等数据集就是在pytorch的Datasets的库中的。
  Pytorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。

二、如何定义Datasets?

Dataset类是Pytorch中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:

def getitem(self, index):
def len(self):

其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数
这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

三、实战

数据集的内容组成
在这里插入图片描述

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from torchvision import transforms

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.layer1=nn.Sequential(
            nn.Linear(3,20),
            nn.Sigmoid(),
            nn.Linear(20,40),
            nn.Sigmoid(),
            nn.Linear(40,1)
        )
    def forward(self,x):
        data=x
        data=self.layer1(data)
        return data

class MyDataset(Dataset):
    def __init__(self,root,transform=None):
        super(MyDataset,self).__init__()
        #读取数据,整理读取的x值为一列
        df=pd.read_csv(root,dtype=np.float32)
        #self.data=pd.DataFrame(columns=['data','label'])
        data=[] #用于获取3个x值并组合为一列
        label=[] #用于获取标签值
        self.data=[]
        self.label=[]
        for i in range(df.shape[0]):
            x=df.loc[i] #type:Series
            data.append([x['x1'],x['x2'],x['x3']])
            label.append(x['y'])
        #self.data['data']=data
        #self.data['label']=label
        self.data=data
        self.label=label
        self.transform=transform



    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        x=self.data[item]
        label=self.label[item]
        if self.transform is not None:
            x=self.transform(x)
        return x,label

class ToTensor(object):
    def __call__(self, seq):
        #print(seq.shape)
        return torch.tensor(seq,dtype=torch.float)

if __name__=='__main__':
    path = 'C:/Users/Mr.Li\Desktop/test project/train.csv'
    set=MyDataset(path,ToTensor())
    data=torch.utils.data.DataLoader(dataset=set,batch_size=6,shuffle=True)
    model=Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_func = torch.nn.MSELoss()

    for epoch in range(100):
        for i,( x,label) in enumerate(data):
            y=model(x)
            z=label.view(-1,1)
            loss = loss_func(y, z)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss)






猜你喜欢

转载自blog.csdn.net/qq_23345187/article/details/123164323