一、数据准备
总结:
RandomDataset :用于验证 (val)
BatchDataset:用于训练 (train)
BalancedBatchSampler:决定如何采样样本,不是简单的在Dataloader中设置一个batch_size了
1)导入的包类:
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler#此处的BatchSampler相当于在Dataloader中设置的batch_size
2)读取图片函数
学习点:考虑读取图片失败情况,使用try…except结构,并将读取失败的图路径存储下来,并返回一个全为白色的同尺度新图。
def default_loader(path):
'''
打开图片并转化为RGB,打开失败,则记录下来,并返回一个新图
'''
try:
img = Image.open(path).convert('RGB')
except:
with open('read_error.txt', 'a') as fid:
fid.write(path+'\n')
return Image.new('RGB', (224,224), 'white')
return img
3)RandomDataset类
此类用于选取指定index的样本,返回的是一张图片以及对应的标签。
批量获取是在Dataloader中设置的。
class RandomDataset(Dataset):
def __init__(self, transform=None, dataloader=default_loader):#此处dataloader用不上
self.transform = transform
self.dataloader = dataloader
with open('val.txt', 'r') as fid:#将图片路径以及标签读取出来
self.imglist = fid.readlines()
def __getitem__(self, index):
image_name, label = self.imglist[index].strip().split() #获取对应的路径以及标签
image_path = image_name
img = self.dataloader(image_path)
img = self.transform(img)
label = int(label) #特别注意,要将label设置为int类型
label = torch.LongTensor([label])
return [img, label] #注意官网推荐使用字典{'image':img,'label':label}
def __len__(self):
return len(self.imglist)
4)BatchDataset类
此类与上述类类似。
class BatchDataset(Dataset):
def __init__(self, transform=None, dataloader=default_loader):
self.transform = transform
self.dataloader = dataloader
with open('train.txt', 'r') as fid:
self.imglist = fid.readlines()
self.labels = []
for line in self.imglist:
image_path, label = line.strip().split()
self.labels.append(int(label))
self.labels = np.array(self.labels)
self.labels = torch.LongTensor(self.labels)
def __getitem__(self, index):
image_name, label = self.imglist[index].strip().split()
image_path = image_name
img = self.dataloader(image_path) #载入数据
img = self.transform(img)
label = int(label)
label = torch.LongTensor([label])
return [img, label]
def __len__(self):
return len(self.imglist)
5)BalancedBatchSampler类
此代码没有初始化父类,可能是用不到父类的变量。
- 获取所有样本的
class BalancedBatchSampler(BatchSampler):
def __init__(self, dataset, n_classes, n_samples):
'''
获取每类样本对应的索引,用字典保存,并将每类的索引打乱,因为此采样器返回的就是索引列表,用于在Dataset中获取样本的,相当于其中的index
'''
self.labels = dataset.labels #所有样本的labels
self.labels_set = list(set(self.labels.numpy())) #0~199,如果是1~200会报错的
self.label_to_indices = {
label: np.where(self.labels.numpy() == label)[0] #返回每类中的样本对应的index,字典
for label in self.labels_set}
for l in self.labels_set:
np.random.shuffle(self.label_to_indices[l]) #将每类对应的索引打乱
self.used_label_indices_count = {
label: 0 for label in self.labels_set} ##每类样本使用过的数量
self.count = 0 #用过的图片数量,用于统计看够不够下一个batch用
self.n_classes = n_classes
self.n_samples = n_samples
self.dataset = dataset
self.batch_size = self.n_samples * self.n_classes#此处保存一个batch样本的数量
def __iter__(self):
self.count = 0 #只用关心样本使用一次的事情,也就是一个epoch后就归零了
while self.count + self.batch_size < len(self.dataset): #也就是使用过的图片数量再加上一个batch仍然小于总数,那么可以继续提供一个batch的图片
classes = np.random.choice(self.labels_set, self.n_classes, replace=False) #选类别,不放回的抽,也就是抽出来的不能有重复
indices = []
for class_ in classes: # 1 , 3 ,4
indices.extend(self.label_to_indices[class_][ #label_to_indices是一个字典,{class:[index]}
self.used_label_indices_count[class_]:self.used_label_indices_count[
class_] + self.n_samples]) #顺序获取每类中的n个样本
self.used_label_indices_count[class_] += self.n_samples #每类样本使用过的数量增加此次使用的样本数量
if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): #使用过的加上下次的样本沟用不够,大于表示不够下次使用了
np.random.shuffle(self.label_to_indices[class_])#不够下次使用,那就将其重新打乱,并将每类的使用数量归零
self.used_label_indices_count[class_] = 0
yield indices #每类都获取到了后,就送出去,送出去的是样本的索引
self.count += self.n_classes * self.n_samples #增加本批样本数量
def __len__(self):
return len(self.dataset) // self.batch_size
6)具体应用:
train_dataset = BatchDataset(transform=transforms.Compose([
transforms.Resize([512, 512]),
transforms.RandomCrop([448, 448]),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)
)]))
train_sampler = BalancedBatchSampler(train_dataset, args.n_classes, args.n_samples) #用于设置每批样本的来源,其返回的是样本的索引indices
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_sampler=train_sampler,#不再设置batch_size,使用batch_sampler
num_workers=args.workers, pin_memory=True) #num_works是线程,pin_memory不太懂,没看懂
二、总结
1.读取图片,考虑读取失败的情况,并且要考虑失败后进行记录,使用try…except…的结构;
2.样本的获取分为采样以及获取实例两步,正常情况下,通过设置batch_size即可不用设置采样器,只需要设置Dataset数据集即可;
3.设置样本采样器,继承自torch.utils.data.sampler.BatchSampler,之后实现三个函数,分别是__init__()、iter()、以及__len__()等,其中的iter函数中使用yield生成一个生成器,不断送出采样的索引列表;
4. 设置数据集Dataset,继承自torch.utils.data.Dataset类,之后实现三个函数,分别是__init__()、getitem()、len(),getitem()函数是返回实例与标签的字典。