pytorch随机采样的方法SubsetRandomSampler()

这篇文章记录一个采样器都随机地从原始的数据集中抽样数据。抽样数据采用permutation。 生成任意一个下标重排,从而利用下标来提取dataset中的数据的方法

需要的库

import torch

使用方法

这里以MNIST举例

train_dataset = dsets.MNIST(root='./data',  #文件存放路径
                            train=True,   #提取训练集
                            transform=transforms.ToTensor(),  #将图像转化为Tensor
                            download=True)

sample_size = len(train_dataset)
sampler1 = torch.utils.data.sampler.SubsetRandomSampler(
    np.random.choice(range(len(train_dataset)), sample_size))

代码详解

np.random.choice()

#numpy.random.choice(a, size=None, replace=True, p=None)
#从a(只要是ndarray都可以,但必须是一维的)中随机抽取数字,并组成指定大小(size)的数组
#replace:True表示可以取相同数字,False表示不可以取相同数字
#数组p:与数组a相对应,表示取数组a中每个元素的概率,默认为选取每个元素的概率相同。

那么这里就相当于抽取了一个全排列
torch.utils.data.sampler.SubsetRandomSampler

# 会根据后面给的列表从数据集中按照下标取元素
# class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

所以就可以了

发布了59 篇原创文章 · 获赞 19 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/weixin_43914889/article/details/104607114
今日推荐