PyTorch: carga de datos de un conjunto de datos personalizado

from torch.utils.data import Dataset,DataLoader
from torchvision import datasets,transforms
import os
from PIL import Image
import numpy as np
import torch

class My_dataset(Dataset):
    """Image dataset with integer labels, backed by a comma-separated index file.

    Each non-blank line of the index file has the form
    ``<image_path>,<label_1>,<label_2>,...``. The first field is the path
    passed to ``PIL.Image.open``. The remaining fields are kept as strings
    and parsed to ``int`` when an item is fetched.
    """

    # Per-channel ImageNet statistics used by both transform pipelines.
    # The original code had mean=[0.485, 0.456, 0.486] (typo for 0.406)
    # and reused the mean as the std.
    NORM_MEAN = (0.485, 0.456, 0.406)
    NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, root_path, is_train=True, is_miniTrain=False, mini_size=700):
        """Read the index file.

        Args:
            root_path: path to the UTF-8 index file (a file, not a directory).
            is_train: if True, ``__getitem__`` applies ``train_transform``
                (random augmentation); otherwise ``others_transform``.
            is_miniTrain: if True, keep only the first ``mini_size`` entries.
            mini_size: number of entries kept when ``is_miniTrain`` is True.
        """
        super().__init__()
        self.is_train = is_train

        self.x = []  # image paths
        self.y = []  # per-image label fields, still as strings
        # `with` guarantees the index file is closed (the original leaked it).
        with open(root_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    # A blank (e.g. trailing) line would yield an empty image path.
                    continue
                path, *labels = line.split(',')
                self.x.append(path)
                self.y.append(labels)

        if is_miniTrain:
            self.x = self.x[:mini_size]
            self.y = self.y[:mini_size]

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(image_tensor, labels)`` for entry ``index``.

        ``labels`` is a ``list`` of ``int`` parsed from the index-file fields.
        """
        path = self.x[index]
        # Load eagerly and close the file handle. Converting to RGB makes PNGs
        # with an alpha channel (or grayscale images) match the 3-channel Normalize.
        with Image.open(path) as src:
            img = src.convert('RGB')

        transform = self.train_transform if self.is_train else self.others_transform
        img = transform(img)

        label = [int(field) for field in self.y[index]]
        return img, label

    def _normalize(self):
        """Build the Normalize transform shared by both pipelines."""
        return transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD)

    def train_transform(self, x):
        """Augment, tensorize and normalize the PIL image ``x`` for training."""
        return transforms.Compose([
            # Pad 28 px on every side, then take a random 224x224 crop.
            transforms.RandomCrop(224, padding=28),
            # A single number d means an angle sampled from (-d, +d) degrees.
            transforms.RandomRotation(0.5),
            transforms.ToTensor(),
            self._normalize(),
        ])(x)

    def others_transform(self, x):
        """Tensorize and normalize the PIL image ``x`` without augmentation.

        No resizing is done here; presumably the images already have the
        size the model expects (TODO confirm).
        """
        return transforms.Compose([
            transforms.ToTensor(),
            self._normalize(),
        ])(x)

Te puede interesar

Origen: https://blog.csdn.net/qq_43586192/article/details/111536254
Recomendado
Clasificación