Two ways to load data in PyTorch

I. Wrap the data in a Dataset class, then feed it to a DataLoader

1. Wrap the data in a Dataset subclass:

Turn the data collection into a subclass of Dataset, which must implement:
__init__ to load the dataset,
__len__ to return the number of samples, which determines how many iterations the loading loop runs,
__getitem__ to fetch a single sample from the dataset by index
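Before the full examples, here is a minimal sketch of that contract. TensorPairData and the toy tensors are made up purely for illustration:

import torch
from torch.utils.data import Dataset, DataLoader

class TensorPairData(Dataset):          # hypothetical toy Dataset, just to show the contract
    def __init__(self, xs, ys):         # receive / load the data
        self.xs = xs
        self.ys = ys

    def __len__(self):                  # number of samples -> number of loop iterations
        return len(self.xs)

    def __getitem__(self, index):       # return the sample at position index
        return self.xs[index], self.ys[index]

data = TensorPairData(torch.randn(10, 3), torch.randn(10, 1))
loader = DataLoader(data, batch_size=4, shuffle=True)
for x, y in loader:
    print(x.shape, y.shape)             # e.g. torch.Size([4, 3]) torch.Size([4, 1]); the last batch may be smaller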

Example 1 (PIL + torchvision transforms):

import torch
import os
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

class GetData(Dataset):
    def __init__(self,path0,path1): # collect the lists of file names in both folders
        super(GetData,self).__init__()
        self.path0 = path0
        self.path1 = path1
        self.name0_list = os.listdir(self.path0)
        self.name1_list = os.listdir(self.path1)
        self.img2data = transforms.Compose([transforms.ToTensor()])

    def __len__(self): # assumes both folders contain the same number of images, paired by listing order
        return len(self.name0_list)

    def __getitem__(self, index): # fetch one pair of images by name; index is a sample index, batching is done by the DataLoader
        name0 = self.name0_list[index]
        name1 = self.name1_list[index]
        img0 = Image.open(os.path.join(self.path0, name0))
        img1 = Image.open(os.path.join(self.path1, name1))
        imgdata0 = self.img2data(img0)
        imgdata1 = self.img2data(img1)

        return imgdata0, imgdata1
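A quick way to sanity-check this Dataset before wiring it into training (the paths are the author's, from the Trainer below; adjust them to your own data) might look like:

from torch.utils.data import DataLoader

data = GetData(path0=r'C:\Users\87419\Desktop\cg1\64',
               path1=r'C:\Users\87419\Desktop\cg1\dama_64')
loader = DataLoader(data, batch_size=128, shuffle=True)
imgdata0, imgdata1 = next(iter(loader))
print(imgdata0.shape, imgdata1.shape)   # two N x C x H x W batches, N up to batch_size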




# The Trainer below consumes this Dataset. MainNet is a model defined elsewhere,
# and GetData is assumed to be importable as the module `dataset` used further down.
from torch.utils.data import DataLoader

class Trainer(nn.Module):
    def __init__(self):
        super(Trainer,self).__init__()

        self.main_net = MainNet()
        self.main_net.cuda() # moving the main network to the GPU also moves every sub-network inside it

        # two kinds of loss are involved, so there are two corresponding optimizers for backpropagation
        vae_parameters = []
        vae_parameters.extend(self.main_net.encoder.parameters())
        vae_parameters.extend(self.main_net.decoder.parameters())

        self.opt_dis = torch.optim.Adam(self.main_net.discriminator.parameters(), lr=1e-3)
        self.opt_vae = torch.optim.Adam(vae_parameters, lr=1e-3)

    def train(self): # note: this shadows nn.Module.train(), which normally switches the module to training mode
        for epoch in range(10000):

            if os.path.exists('encoder.pkl'):
                self.main_net.encoder.load_state_dict(torch.load('encoder.pkl'))
            if os.path.exists('decoder.pkl'):
                self.main_net.decoder.load_state_dict(torch.load('decoder.pkl'))
            if os.path.exists('discriminator.pkl'):
                self.main_net.discriminator.load_state_dict(torch.load('discriminator.pkl'))

            self.dataloader = DataLoader(dataset.GetData(path0=r'C:\Users\87419\Desktop\cg1\64',
                             path1=r'C:\Users\87419\Desktop\cg1\dama_64'), batch_size=128, shuffle=True)
            count = 0

            # each epoch walks through the whole dataloader, i.e. all 50,000 images; every count (loop iteration) processes one batch of batch_size images
            # len(dataloader) = number of images / batch size, e.g. 782 = ceil(50000 / 64), so the loader length equals the total number of counts per epoch
            for img0data, img1data in self.dataloader:

                img0data = img0data.cuda() # move the inputs to the GPU; everything computed from them then runs on the GPU as well
                img1data = img1data.cuda()

                count += 1
                print(count)
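The relationship in the comment above can be checked directly: with the default drop_last=False, len(dataloader) is the ceiling of the dataset size divided by the batch size.

import math

num_images = 50000
batch_size = 64
print(math.ceil(num_images / batch_size))   # 782 iterations (counts) per epoch
# with drop_last=True the last incomplete batch is dropped, giving 50000 // 64 = 781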

Example 2 (OpenCV, with file names read from label.txt):


import torch
import os
import numpy as np
import cv2
from torch.utils.data import Dataset,DataLoader
 
class GetData(Dataset):
    def __init__(self,path1,path2):
        super(GetData,self).__init__()
        self.path1 = path1
        self.path2 = path2
        self.dataset1 = []
        self.dataset2 = []
        self.dataset1.extend(open(os.path.join(self.path1,'label.txt')).readlines())
        self.dataset2.extend(open(os.path.join(self.path2,'label.txt')).readlines())
 
    def __getitem__(self, index): # index selects a single sample, not a batch; the DataLoader groups samples into batches of batch_size
        str1 = self.dataset1[index].strip() # e.g. dataset1[0] is the file name of the first sample
        str2 = self.dataset2[index].strip()
        imgpath1 = os.path.join(self.path1,str1)
        imgpath2 = os.path.join(self.path2,str2)
        im1 = cv2.imread(imgpath1)
        im2 = cv2.imread(imgpath2)
        imgdata1 = torch.Tensor(im1 / 255. - 0.5) # scale pixel values to [-0.5, 0.5]; the array is still H x W x C in BGR order
        imgdata2 = torch.Tensor(im2 / 255. - 0.5)
        return imgdata1,imgdata2
 
    def __len__(self):
        return len(self.dataset1)
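Note that cv2.imread returns an H x W x C array in BGR channel order, so the tensors returned above keep that layout. If the samples are going into nn.Conv2d layers, which expect N x C x H x W input, one possible adjustment inside __getitem__ is the following sketch (not the author's code; the same applies to im2):

        im1 = cv2.imread(imgpath1)
        im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB)                               # optional: BGR -> RGB
        imgdata1 = torch.from_numpy(im1).permute(2, 0, 1).float() / 255. - 0.5   # H,W,C -> C,H,W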

2. Then feed it to the DataLoader

for i in range(EPOCHES): # EPOCHES and BATCH_SIZE are assumed to be defined earlier
    print('epoch:',i)
    dataset = GetData(r'C:\Users\87419\Desktop\VAE1\data\trainB', r'C:\Users\87419\Desktop\VAE1\data\trainA')
 
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for j,(imgdata1,imgdata2) in enumerate(dataloader):
      
        imgdata1_ = imgdata1.cuda()
        imgdata2_ = imgdata2.cuda()

II. Use torchvision

import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

transform1 = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
    ])
dataset = datasets.ImageFolder(root=r'C:\Users\87419\Desktop\VAE\faces',transform=transform1)
dataset_loader = DataLoader(dataset,batch_size=4, shuffle=True)
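ImageFolder expects the root directory to contain one sub-folder per class and yields (image, class_index) pairs; with a single sub-folder every label is 0. Also note that Resize(256) with an int scales the shorter side to 256, so when source images have varying aspect ratios a CenterCrop is usually added before ToTensor to keep batch shapes consistent. Iterating the loader then looks like:

for imgs, labels in dataset_loader:
    print(imgs.shape, labels)   # imgs: a batch of 4 image tensors; labels: class indices taken from the sub-folder names
    break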

Or define a data-loading function:

def sample_data(path,batch_size,size=4):
    transform = transforms.Compose([transforms.Resize(size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
    dataset = datasets.ImageFolder(path,transform=transform)
    dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)

    return dataloader
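Called, for example, like this (reusing the faces directory from above; size=64 is just an illustrative value passed to Resize):

loader = sample_data(r'C:\Users\87419\Desktop\VAE\faces', batch_size=4, size=64)
imgs, labels = next(iter(loader))
print(imgs.shape)   # e.g. torch.Size([4, 3, 64, 64]) when the source images are square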


Reprinted from blog.csdn.net/qq_39938666/article/details/88429064