Notas de estudio de Pytorch (11): conjunto de datos personalizados de pytorch

1. ¿Por qué utilizar la clase Conjuntos de datos?

  Datasets es una clase de pytorch. Pytorch viene con una variedad de conjuntos de datos, como: MINIST y otros conjuntos de datos se encuentran en la biblioteca de conjuntos de datos de pytorch.
  Hay una función de herramienta torch.utils.Data.DataLoader en Pytorch. A través de esta función, podemos usar el procesamiento paralelo de subprocesos múltiples cuando nos preparamos para cargar el conjunto de datos usando mini-batch, lo que puede acelerar nuestra preparación del conjunto de datos. Los conjuntos de datos son uno de los parámetros de instancia para crear esta función de utilidad.

2. ¿Cómo definir conjuntos de datos?

La clase Dataset es la clase más importante en Pytorch y también es la clase principal que debe heredarse en todas las clases de carga de conjuntos de datos en Pytorch. Las dos funciones miembro privadas en la clase principal deben sobrecargarse; de ​​lo contrario, se generará un mensaje de error:

def getitem(self, index):
def len(self):

Entre ellos, __len__ debería devolver el tamaño del conjunto de datos y __getitem__ debería escribir una función que admita el índice del conjunto de datos.
Aquí nos centramos en la función getitem. getitem recibe un índice y luego devuelve los datos y las etiquetas de la imagen. Este índice generalmente se refiere a un índice de lista, cada elemento de esta lista contiene la ruta y la información de la etiqueta de los datos de la imagen.

Tres, combate real

La composición del conjunto de datos.
inserte la descripción de la imagen aquí

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from torchvision import transforms

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.layer1=nn.Sequential(
            nn.Linear(3,20),
            nn.Sigmoid(),
            nn.Linear(20,40),
            nn.Sigmoid(),
            nn.Linear(40,1)
        )
    def forward(self,x):
        data=x
        data=self.layer1(data)
        return data

class MyDataset(Dataset):
    def __init__(self,root,transform=None):
        super(MyDataset,self).__init__()
        #读取数据,整理读取的x值为一列
        df=pd.read_csv(root,dtype=np.float32)
        #self.data=pd.DataFrame(columns=['data','label'])
        data=[] #用于获取3个x值并组合为一列
        label=[] #用于获取标签值
        self.data=[]
        self.label=[]
        for i in range(df.shape[0]):
            x=df.loc[i] #type:Series
            data.append([x['x1'],x['x2'],x['x3']])
            label.append(x['y'])
        #self.data['data']=data
        #self.data['label']=label
        self.data=data
        self.label=label
        self.transform=transform



    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        x=self.data[item]
        label=self.label[item]
        if self.transform is not None:
            x=self.transform(x)
        return x,label

class ToTensor(object):
    def __call__(self, seq):
        #print(seq.shape)
        return torch.tensor(seq,dtype=torch.float)

if __name__=='__main__':
    path = 'C:/Users/Mr.Li\Desktop/test project/train.csv'
    set=MyDataset(path,ToTensor())
    data=torch.utils.data.DataLoader(dataset=set,batch_size=6,shuffle=True)
    model=Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_func = torch.nn.MSELoss()

    for epoch in range(100):
        for i,( x,label) in enumerate(data):
            y=model(x)
            z=label.view(-1,1)
            loss = loss_func(y, z)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss)






Supongo que te gusta

Origin blog.csdn.net/qq_23345187/article/details/123164323
Recomendado
Clasificación