torch.utils.data.Sampler

torch.utils.data.DataLoader has a parameter sampler whose default value is None. The sampler and batch_sampler parameters let the user specify the order in which data is loaded and the sampling strategy.

A sampler's __iter__ returns sample indices: sampler yields one index at a time, while batch_sampler yields a list of indices per batch.

Sequential sampling

from torch.utils.data import Sampler

class SequentialSampler(Sampler):
    """Samples elements sequentially, always in the same order."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Yield indices 0, 1, ..., len(data_source) - 1 in order.
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
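As a quick illustration (a minimal sketch; the TensorDataset and batch size here are my assumptions, not from the original), passing this sampler to a DataLoader makes it iterate the dataset in fixed index order:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=4)
for (batch,) in loader:
    print(batch)  # tensor([0, 1, 2, 3]), then tensor([4, 5, 6, 7]), then tensor([8, 9])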

Balanced sampling

If there are n classes, then in every batch each class contributes k samples, so batch_size = n * k.

import numpy as np
from torch.utils.data import BatchSampler

class BalancedBatchSampler(BatchSampler):
    """Yields batches of n_classes * n_samples indices, drawing n_samples
    indices from each of n_classes randomly chosen classes."""

    def __init__(self, labels, args):
        # args.classes = n classes per batch, args.batch_k = k samples per class
        self.labels = labels
        self.labels_set = list(set(self.labels.numpy()))
        # Map each label to the indices of its samples.
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        # Track how many indices of each label have been consumed so far.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = args.classes
        self.n_samples = args.batch_k
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < self.n_dataset:
            # Randomly pick n_classes distinct classes for this batch.
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                # Take the next n_samples unused indices of this class.
                indices.extend(self.label_to_indices[class_][
                               self.used_label_indices_count[class_]:
                               self.used_label_indices_count[class_] + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # Reshuffle and restart once a class runs out of fresh indices.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return self.n_dataset // self.batch_size

Usage

batch_sampler_train = BalancedBatchSampler(train_dataset.train_labels, args)
batch_sampler_test = BalancedBatchSampler(test_dataset.test_labels, args)

kwargs = {'num_workers': 0, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(batch_sampler=batch_sampler_train, dataset=train_dataset, **kwargs)
test_loader = torch.utils.data.DataLoader(batch_sampler=batch_sampler_test, dataset=test_dataset, **kwargs)

Note that when batch_sampler is given, batch_size and shuffle must not be passed to torch.utils.data.DataLoader (batch_sampler is mutually exclusive with batch_size, shuffle, sampler, and drop_last).
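As a sanity check (a minimal sketch; the synthetic labels and the SimpleNamespace stand-in for args are my assumptions, not from the original), each batch yielded by the sampler contains args.classes * args.batch_k indices, balanced across the chosen classes:

import torch
from types import SimpleNamespace

# Hypothetical toy data: 3 classes with 20 samples each.
labels = torch.tensor([c for c in range(3) for _ in range(20)])
args = SimpleNamespace(classes=2, batch_k=4)  # 2 classes x 4 samples -> batch of 8

sampler = BalancedBatchSampler(labels, args)
batch = next(iter(sampler))
print(len(batch), sorted(labels[i].item() for i in batch))
# e.g. 8 [0, 0, 0, 0, 2, 2, 2, 2]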
The __iter__ method above is a generator function thanks to the yield statement; for background on iterators and generators see https://www.runoob.com/python3/python3-iterator-generator.html
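For reference, a minimal sketch of how yield works: calling a function that contains yield returns a generator that produces values lazily, which is exactly how __iter__ streams one batch of indices at a time:

def count_up_to(n):
    i = 0
    while i < n:
        yield i  # execution pauses here until the caller asks for the next value
        i += 1

print(list(count_up_to(3)))  # [0, 1, 2]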

Reference: https://blog.csdn.net/SweetWind1996/article/details/105328385

Reposted from blog.csdn.net/weixin_42764932/article/details/112651868