版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/a19990412/article/details/83934447
简述
调用方法与 `torchvision.datasets.MNIST` 相同，只需在参数列表最后额外传入 `targetNum=<数字>`，即可只保留该数字对应的样本。
注意：筛选后的样本数会被截断为 `BATCH_SIZE` 的整数倍，因此使用前需要先定义全局常量 `BATCH_SIZE`。
代码
class myMNIST(torchvision.datasets.MNIST):
    """MNIST dataset that can be restricted to a single digit class.

    Behaves exactly like ``torchvision.datasets.MNIST``. The extra
    ``targetNum`` argument keeps only the samples whose label equals
    ``targetNum``. The kept samples are then truncated to a whole number of
    batches.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, targetNum=None, batch_size=None):
        """Load MNIST and optionally keep only one label.

        Args:
            root, train, transform, target_transform, download: forwarded
                unchanged to ``torchvision.datasets.MNIST``.
            targetNum: label value to keep; ``None`` keeps every sample.
                ``0`` is a valid label, so it is compared against ``None``
                explicitly rather than by truthiness.
            batch_size: the filtered subset is truncated to a multiple of this
                value. When ``None``, the module-level ``BATCH_SIZE`` constant
                is used, which matches the original behaviour.
        """
        super(myMNIST, self).__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download)
        if targetNum is not None:
            if batch_size is None:
                # Original code relied on a module-level BATCH_SIZE constant.
                batch_size = BATCH_SIZE
            data_attr, label_attr = self._storage_names()
            labels = getattr(self, label_attr)
            # Compute the mask once; it selects both images and labels.
            mask = labels == targetNum
            kept_data = getattr(self, data_attr)[mask]
            kept_labels = labels[mask]
            # Drop the tail so the subset splits into full batches only.
            keep = (len(kept_data) // batch_size) * batch_size
            setattr(self, data_attr, kept_data[:keep])
            setattr(self, label_attr, kept_labels[:keep])

    def _storage_names(self):
        """Return the (data, labels) attribute names holding this split."""
        # Newer torchvision stores samples in ``data``/``targets`` and exposes
        # ``train_data``/``train_labels`` only as deprecated properties without
        # a setter, so assigning to the old names would fail there.
        if 'data' in vars(self):
            return 'data', 'targets'
        # Older torchvision keeps separate attributes for each split.
        if self.train:
            return 'train_data', 'train_labels'
        return 'test_data', 'test_labels'

    def __len__(self):
        """Return the number of samples in this split after any filtering."""
        data_attr, _ = self._storage_names()
        return len(getattr(self, data_attr))