torch.utils.data.DataLoader has a sampler parameter whose default value is None. The sampler and batch_sampler parameters let you control the order in which data are loaded and how samples are drawn; both work by returning sample indices.
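As a rough illustration of the contract (a minimal sketch, not tied to any particular dataset): a sampler yields one index per step and DataLoader groups those indices into batches of batch_size, while a batch_sampler yields a whole list of indices per batch. In recent PyTorch versions both parameters also accept plain iterables.

import torch
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler

dataset = TensorDataset(torch.arange(10).float())

# sampler: yields single indices; DataLoader groups them into batches of 4
loader_a = DataLoader(dataset, batch_size=4, sampler=SequentialSampler(dataset))

# batch_sampler: yields one list of indices per batch (here a plain list of lists);
# batch_size and shuffle must be left at their defaults in this case
loader_b = DataLoader(dataset, batch_sampler=[[0, 1, 2], [7, 8, 9]])

for batch in loader_b:
    print(batch)  # tensors built from indices [0, 1, 2] and then [7, 8, 9]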
Sequential sampling
from torch.utils.data import Sampler

class SequentialSampler(Sampler):
    """Iterates over dataset indices in order: 0, 1, ..., len(data_source) - 1."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
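A quick check with PyTorch's built-in SequentialSampler (which behaves the same as the snippet above) on a toy dataset of length 5:

import torch
from torch.utils.data import TensorDataset, SequentialSampler

toy = TensorDataset(torch.zeros(5))
print(list(SequentialSampler(toy)))   # [0, 1, 2, 3, 4] -- indices in dataset order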
Balanced sampling (class-uniform batches)
If there are n classes and each batch contains k samples of every class, then batch_size = n * k (for example, n = 10 classes with k = 8 samples per class gives a batch size of 80).
import numpy as np
from torch.utils.data.sampler import BatchSampler

class BalancedBatchSampler(BatchSampler):
    """Yields batches of n_classes * n_samples indices, n_samples per class."""

    def __init__(self, labels, args):
        self.labels = labels
        self.labels_set = list(set(self.labels.numpy()))
        # map each label to the (shuffled) array of indices carrying that label
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        # how many indices of each class have been consumed so far
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = args.classes    # n: classes per batch
        self.n_samples = args.batch_k    # k: samples per class
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < self.n_dataset:
            # pick n_classes distinct classes for this batch
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # reshuffle a class once its indices are nearly exhausted
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return self.n_dataset // self.batch_size
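A minimal sanity check with synthetic labels (the SimpleNamespace here simply stands in for the args object; its classes and batch_k fields are the ones read in __init__ above):

import torch
from types import SimpleNamespace

# 3 classes with 8 samples each -> 24 labels in total
labels = torch.tensor([0] * 8 + [1] * 8 + [2] * 8)
args = SimpleNamespace(classes=3, batch_k=2)   # n = 3 classes, k = 2 samples per class

sampler = BalancedBatchSampler(labels, args)
for batch_indices in sampler:
    # every batch holds 3 * 2 = 6 indices, exactly 2 per class
    print(batch_indices, [labels[i].item() for i in batch_indices])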
Usage
batch_sampler_train = BalancedBatchSampler(train_dataset.train_labels, args)
batch_sampler_test = BalancedBatchSampler(test_dataset.test_labels, args)

# `cuda` is assumed to be set earlier, e.g. cuda = torch.cuda.is_available()
kwargs = {'num_workers': 0, 'pin_memory': True} if cuda else {}

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_sampler=batch_sampler_train, **kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_sampler=batch_sampler_test, **kwargs)
Note that when batch_sampler is passed, batch_size and shuffle must not be given to torch.utils.data.DataLoader: batch_sampler is mutually exclusive with batch_size, shuffle, sampler, and drop_last.
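If batch_size or shuffle is passed anyway, DataLoader raises a ValueError. A hedged illustration (the exact error message may differ between PyTorch versions):

try:
    torch.utils.data.DataLoader(train_dataset, batch_size=32,
                                batch_sampler=batch_sampler_train)
except ValueError as err:
    print(err)   # "batch_sampler option is mutually exclusive with batch_size, shuffle, ..."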
References:
yield and generators: https://www.runoob.com/python3/python3-iterator-generator.html
https://blog.csdn.net/SweetWind1996/article/details/105328385