API_Net官方代码之数据处理

在这里插入图片描述

一、数据准备

总结:
RandomDataset :用于验证 (val)
BatchDataset:用于训练 (train)
BalancedBatchSampler:决定如何采样样本,不是简单的在Dataloader中设置一个batch_size了

1)导入的包类:

import torch
from PIL import  Image
import numpy as np
from torchvision import transforms

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler#此处的BatchSampler相当于在Dataloader中设置的batch_size

2)读取图片函数

学习点:考虑读取图片失败情况,使用try…except结构,并将读取失败的图路径存储下来,并返回一个全为白色的同尺度新图。

def default_loader(path):
    '''
    打开图片并转化为RGB,打开失败,则记录下来,并返回一个新图
    '''
    try:
        img = Image.open(path).convert('RGB')
    except:
        with open('read_error.txt', 'a') as fid:
            fid.write(path+'\n')
        return Image.new('RGB', (224,224), 'white')
    return img

3)RandomDataset类

此类用于选取指定index的样本,返回的是一张图片以及对应的标签。
批量获取是在Dataloader中设置的。

class RandomDataset(Dataset):
    def __init__(self, transform=None, dataloader=default_loader):#此处dataloader用不上
        self.transform = transform
        self.dataloader = dataloader

        with open('val.txt', 'r') as fid:#将图片路径以及标签读取出来
            self.imglist = fid.readlines()

    def __getitem__(self, index):
        image_name, label = self.imglist[index].strip().split() #获取对应的路径以及标签
        image_path = image_name
        img = self.dataloader(image_path)
        img = self.transform(img)
        label = int(label) #特别注意,要将label设置为int类型
        label = torch.LongTensor([label])

        return [img, label] #注意官网推荐使用字典{'image':img,'label':label}


    def __len__(self):
        return len(self.imglist)

4)BatchDataset类

此类与上述类类似。

class BatchDataset(Dataset):
    def __init__(self, transform=None, dataloader=default_loader):
        self.transform = transform
        self.dataloader = dataloader
        with open('train.txt', 'r') as fid:
            self.imglist = fid.readlines()
        self.labels = []
        for line in self.imglist:
            image_path, label = line.strip().split()
            self.labels.append(int(label))
        self.labels = np.array(self.labels)
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, index):
        image_name, label = self.imglist[index].strip().split()
        image_path = image_name
        img = self.dataloader(image_path) #载入数据
        img = self.transform(img)
        label = int(label)
        label = torch.LongTensor([label])
        return [img, label]

    def __len__(self):
        return len(self.imglist)

5)BalancedBatchSampler类
此代码没有初始化父类,可能是用不到父类的变量。

  • 获取所有样本的
class BalancedBatchSampler(BatchSampler):
    def __init__(self, dataset, n_classes, n_samples):
    '''
	获取每类样本对应的索引,用字典保存,并将每类的索引打乱,因为此采样器返回的就是索引列表,用于在Dataset中获取样本的,相当于其中的index
	'''
        self.labels = dataset.labels #所有样本的labels
        self.labels_set = list(set(self.labels.numpy())) #0~199,如果是1~200会报错的
        self.label_to_indices = {
    
    label: np.where(self.labels.numpy() == label)[0] #返回每类中的样本对应的index,字典
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l]) #将每类对应的索引打乱

        self.used_label_indices_count = {
    
    label: 0 for label in self.labels_set}  ##每类样本使用过的数量
        self.count = 0 #用过的图片数量,用于统计看够不够下一个batch用

        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset

        self.batch_size = self.n_samples * self.n_classes#此处保存一个batch样本的数量

    def __iter__(self):
        self.count = 0 #只用关心样本使用一次的事情,也就是一个epoch后就归零了
        while self.count + self.batch_size < len(self.dataset): #也就是使用过的图片数量再加上一个batch仍然小于总数,那么可以继续提供一个batch的图片
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False) #选类别,不放回的抽,也就是抽出来的不能有重复
            indices = []
            for class_ in classes: # 1 , 3 ,4
                indices.extend(self.label_to_indices[class_][ #label_to_indices是一个字典,{class:[index]}
                               self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                         class_] + self.n_samples]) #顺序获取每类中的n个样本
                self.used_label_indices_count[class_] += self.n_samples #每类样本使用过的数量增加此次使用的样本数量
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): #使用过的加上下次的样本沟用不够,大于表示不够下次使用了
                    np.random.shuffle(self.label_to_indices[class_])#不够下次使用,那就将其重新打乱,并将每类的使用数量归零
                    self.used_label_indices_count[class_] = 0
            yield indices #每类都获取到了后,就送出去,送出去的是样本的索引
            self.count += self.n_classes * self.n_samples #增加本批样本数量

    def __len__(self):
        return len(self.dataset) // self.batch_size

6)具体应用:

train_dataset = BatchDataset(transform=transforms.Compose([
        transforms.Resize([512, 512]),
        transforms.RandomCrop([448, 448]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225)
        )]))

    train_sampler = BalancedBatchSampler(train_dataset, args.n_classes, args.n_samples) #用于设置每批样本的来源,其返回的是样本的索引indices
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_sampler=train_sampler,#不再设置batch_size,使用batch_sampler
        num_workers=args.workers, pin_memory=True) #num_works是线程,pin_memory不太懂,没看懂

二、总结

1.读取图片,考虑读取失败的情况,并且要考虑失败后进行记录,使用try…except…的结构;
2.样本的获取分为采样以及获取实例两步,正常情况下,通过设置batch_size即可不用设置采样器,只需要设置Dataset数据集即可;
3.设置样本采样器,继承自torch.utils.data.sampler.BatchSampler,之后实现三个函数,分别是__init__()、iter()、以及__len__()等,其中的iter函数中使用yield生成一个生成器,不断送出采样的索引列表;
4. 设置数据集Dataset,继承自torch.utils.data.Dataset类,之后实现三个函数,分别是__init__()、getitem()、len(),getitem()函数是返回实例与标签的字典。

猜你喜欢

转载自blog.csdn.net/YJYS_ZHX/article/details/113518231