PyTorch study notes (11) - PyTorch custom dataset

1. Why use the Dataset class

  Dataset is a PyTorch class, defined in torch.utils.data. The companion library torchvision also ships many ready-made datasets, such as MNIST, in torchvision.datasets.
  PyTorch provides a utility class, torch.utils.data.DataLoader. It loads a dataset in mini-batches and can prepare those batches in parallel with several worker processes (the num_workers argument), which speeds up data preparation. A Dataset instance is one of the arguments used to construct a DataLoader.
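
As a minimal sketch of how a dataset is handed to DataLoader, the example below wraps made-up tensors in the built-in TensorDataset. The toy values, the batch size of 4 and the variable names are assumptions chosen for illustration, not part of the original notes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (hypothetical): 10 samples with 3 features each, one target per sample.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
targets = torch.arange(10, dtype=torch.float32)

dataset = TensorDataset(features, targets)

# num_workers=0 loads batches in the main process; a positive value starts
# that many worker processes to prepare batches in parallel.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

# Collect the shape of every batch the loader yields.
batch_shapes = [(x.shape, y.shape) for x, y in loader]
```

With 10 samples and a batch size of 4, the loader yields three batches, and the last one holds the 2 leftover samples.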

2. How to define a Dataset?

Dataset (torch.utils.data.Dataset) is the base class that every custom map-style dataset in PyTorch should inherit from. A subclass must override two special methods. If __getitem__ is missing, indexing the dataset raises NotImplementedError, and if __len__ is missing, a DataLoader with its default sampler cannot determine the size of the dataset:

def __getitem__(self, index):
def __len__(self):

Here, __len__ should return the size of the dataset, and __getitem__ should return the sample stored at a given index.
We focus on __getitem__. It receives an index and returns the sample data (for example, an image) together with its label. The index usually refers to a position in a list, where each element holds the path and label of one image.
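
The idea of indexing into a list of (path, label) records can be sketched as follows. The class name ListDataset, the file names and the fake_loader stand-in (which returns zeros instead of decoding a real image) are all hypothetical and only serve to make the example runnable.

```python
import torch
from torch.utils.data import Dataset


class ListDataset(Dataset):
    """Hypothetical example: wraps a list of (path, label) records."""

    def __init__(self, records, loader):
        self.records = records  # list of (path, label) pairs
        self.loader = loader    # callable that turns a path into a tensor

    def __len__(self):
        # Size of the dataset = number of records.
        return len(self.records)

    def __getitem__(self, index):
        # Look up the record, load its data, and return (data, label).
        path, label = self.records[index]
        return self.loader(path), label


def fake_loader(path):
    # Stand-in for real image decoding: always returns a 1x2x2 tensor of zeros.
    return torch.zeros(1, 2, 2)


records = [("img_0.png", 0), ("img_1.png", 1), ("img_2.png", 0)]
dataset = ListDataset(records, fake_loader)
sample, label = dataset[1]
```

In a real image dataset, the loader would open the file at the given path and convert it to a tensor.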

3. Hands-on practice

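The composition of the dataset
(Image not available. Judging from the code below, train.csv contains three feature columns, x1, x2 and x3, and one label column, y.)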

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from torchvision import transforms

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.layer1=nn.Sequential(
            nn.Linear(3,20),
            nn.Sigmoid(),
            nn.Linear(20,40),
            nn.Sigmoid(),
            nn.Linear(40,1)
        )
    def forward(self,x):
        data=x
        data=self.layer1(data)
        return data

class MyDataset(Dataset):
    def __init__(self,root,transform=None):
        super(MyDataset,self).__init__()
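        # Read the CSV file; every column is parsed as float32
        df=pd.read_csv(root,dtype=np.float32)
        #self.data=pd.DataFrame(columns=['data','label'])
        data=[] # collects the three x values of each row as one sample
        label=[] # collects the label value of each row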
        self.data=[]
        self.label=[]
        for i in range(df.shape[0]):
            x=df.loc[i] #type:Series
            data.append([x['x1'],x['x2'],x['x3']])
            label.append(x['y'])
        #self.data['data']=data
        #self.data['label']=label
        self.data=data
        self.label=label
        self.transform=transform



    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        x=self.data[item]
        label=self.label[item]
        if self.transform is not None:
            x=self.transform(x)
        return x,label

class ToTensor(object):
    def __call__(self, seq):
        #print(seq.shape)
        return torch.tensor(seq,dtype=torch.float)

if __name__=='__main__':
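    # Use forward slashes throughout: '\D' in a normal string is an invalid escape sequence
    path = 'C:/Users/Mr.Li/Desktop/test project/train.csv'
    # Named dataset rather than set, to avoid shadowing the built-in set type
    dataset=MyDataset(path,ToTensor())
    data=torch.utils.data.DataLoader(dataset=dataset,batch_size=6,shuffle=True)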
    model=Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_func = torch.nn.MSELoss()

    for epoch in range(100):
        for i,( x,label) in enumerate(data):
            y=model(x)
            z=label.view(-1,1)
            loss = loss_func(y, z)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss)
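
The row-by-row loop in MyDataset.__init__ could also be replaced by converting whole columns at once. The sketch below is one possible variant, not the author's code: the class name CsvDataset and the small made-up CSV (written to a temporary directory so the example runs without the original train.csv) are assumptions. Only the column names x1, x2, x3 and y come from the code above.

```python
import os
import tempfile

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class CsvDataset(Dataset):
    """Hypothetical variant of MyDataset that converts whole columns at once."""

    def __init__(self, csv_path):
        df = pd.read_csv(csv_path, dtype=np.float32)
        # torch.tensor copies the NumPy data into new float32 tensors.
        self.x = torch.tensor(df[["x1", "x2", "x3"]].to_numpy(), dtype=torch.float32)
        self.y = torch.tensor(df["y"].to_numpy(), dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]


# Write a tiny made-up CSV so the sketch runs without the original train.csv.
csv_path = os.path.join(tempfile.mkdtemp(), "toy.csv")
pd.DataFrame(
    {
        "x1": [1.0, 2.0, 3.0, 4.0],
        "x2": [0.5, 1.5, 2.5, 3.5],
        "x3": [0.0, 1.0, 0.0, 1.0],
        "y": [1.5, 4.5, 5.5, 8.5],
    }
).to_csv(csv_path, index=False)

dataset = CsvDataset(csv_path)
loader = DataLoader(dataset, batch_size=2, shuffle=False)
x_batch, y_batch = next(iter(loader))
```

Because the tensors are built once in __init__, __getitem__ only indexes into them and no separate ToTensor transform is needed.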







Origin blog.csdn.net/qq_23345187/article/details/123164323