MNIST选取特定数值的训练集

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/a19990412/article/details/83934447

简述

调用的方法跟torchvision.datasets.MNIST的方法类似。
只不过在最后面加参变量targetNum= 选定特定的数值就好了。

代码

class myMNIST(torchvision.datasets.MNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, targetNum=None):
        super(myMNIST, self).__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download)
        if targetNum != None:
            self.train_data = self.train_data[self.train_labels == targetNum]

            self.train_data = self.train_data[:int(self.__len__() / BATCH_SIZE) * BATCH_SIZE]

            self.train_labels = self.train_labels[self.train_labels == targetNum][
                                :int(self.__len__() / BATCH_SIZE) * BATCH_SIZE]

    def __len__(self):
        if self.train:
            return self.train_data.shape[0]
        else:
            return 10000

猜你喜欢

转载自blog.csdn.net/a19990412/article/details/83934447