如何加载mnist和fashion-mnist数据集

版权声明:版权归世界上所有无产阶级所有 https://blog.csdn.net/qq_41776781/article/details/86767374
mnist手写体数据集是人工智能中最简单、最常用的入门数据集之一,数据集下载的链接是:

mnist数据集下载

fashion-mnist数据集的存储形式与mnist数据集完全相同,数据集下载的链接是:

fashion-mnist数据集下载

  • 程序讲解:程序分成三个部分
  • 首先,load_mnist() 用于加载mnist数据集或fashion-mnist数据集。两者格式完全相同,所以建议将它们保存到不同的文件夹下,通过指定文件夹来选择要加载的数据集
  • 其次,for循环按批次取出图像并保存,同时打印对应的标签
  • 最后,save_images将一批图像按网格拼接后保存为一张图片;记得检查图像的数值范围(load_mnist返回的图像已归一化到[0, 1])
def load_mnist(dataset_name=None, data_dir=None):
    """Load MNIST or Fashion-MNIST as one shuffled set with one-hot labels.

    Both datasets ship as the same four gzip-compressed IDX files, so only the
    source directory differs. The 60000 training and 10000 test samples are
    concatenated and shuffled together with ``np.random.shuffle`` (seed the
    global NumPy RNG for a reproducible order).

    Args:
        dataset_name: ``"mnist"`` (used when None) or ``"fashion-mnist"``;
            selects the matching directory under ``../Dataset/``. Ignored
            when ``data_dir`` is given.
        data_dir: Directory holding the four ``*-ubyte.gz`` files. Overrides
            ``dataset_name``.

    Returns:
        ``(X, y_vec)``: ``X`` is float64 of shape (70000, 28, 28, 1) scaled
        to [0, 1]; ``y_vec`` is float64 one-hot of shape (70000, 10).

    Raises:
        ValueError: If ``dataset_name`` is unknown, or a file contains fewer
            data bytes than expected.
    """
    import os

    dataset_dirs = {
        "mnist": "../Dataset/mnist_data/",
        "fashion-mnist": "../Dataset/fashion-mnist/",
    }
    if data_dir is None:
        key = "mnist" if dataset_name is None else dataset_name
        if key not in dataset_dirs:
            raise ValueError(
                "unknown dataset_name {!r}; expected one of {}".format(
                    dataset_name, sorted(dataset_dirs)))
        data_dir = dataset_dirs[key]

    def extract_data(filename, num_data, head_size, data_size):
        # Skip the IDX header (head_size bytes), then read exactly num_data
        # records of data_size unsigned bytes each.
        path = os.path.join(data_dir, filename)
        expected = data_size * num_data
        with gzip.open(path) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(expected)
        if len(buf) != expected:
            # Fail clearly instead of with an opaque reshape error later.
            raise ValueError("{}: expected {} data bytes after the header, got {}".format(
                path, expected, len(buf)))
        return np.frombuffer(buf, dtype=np.uint8).astype(np.float64)

    trX = extract_data('train-images-idx3-ubyte.gz', 60000, 16, 28 * 28).reshape((60000, 28, 28, 1))
    trY = extract_data('train-labels-idx1-ubyte.gz', 60000, 8, 1)
    teX = extract_data('t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28).reshape((10000, 28, 28, 1))
    teY = extract_data('t10k-labels-idx1-ubyte.gz', 10000, 8, 1)

    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(int)

    print("*****************dataX**************", len(X))
    # Shuffle images and labels with the same permutation so pairs stay aligned.
    data_index = np.arange(X.shape[0])
    np.random.shuffle(data_index)
    X = X[data_index]
    y = y[data_index]

    # One-hot encode: row i gets a 1.0 in column y[i].
    y_vec = np.zeros((len(y), 10), dtype=np.float64)
    y_vec[np.arange(len(y)), y] = 1.0

    # X is a fresh copy after fancy indexing, so normalizing in place is safe
    # and avoids allocating another full-size array.
    X /= 255.0
    return X, y_vec
def merge(images, size=None):
    """Tile a batch of images into a single 3-channel grid image.

    Images are placed row by row: image ``idx`` goes to grid row
    ``idx // cols`` and column ``idx % cols``. Single-channel images are
    broadcast to all three output channels; unused grid cells stay 0.

    Args:
        images: Array-like of shape (N, h, w, C) with C in {1, 3}, or
            (N, h, w) for grayscale images without a channel axis.
        size: ``[rows, cols]`` of the grid; defaults to ``[8, 8]``.

    Returns:
        float64 array of shape (h * rows, w * cols, 3).

    Raises:
        ValueError: If N is larger than ``rows * cols``.
    """
    rows, cols = (8, 8) if size is None else (size[0], size[1])
    images = np.asarray(images)
    if images.ndim == 3:
        # Add the missing channel axis so it can broadcast to 3 channels.
        images = images[..., np.newaxis]
    if images.shape[0] > rows * cols:
        raise ValueError("cannot place {} images in a {}x{} grid".format(
            images.shape[0], rows, cols))
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * rows, w * cols, 3))
    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        img[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
    return img
def save_images(images, image_path):
    """Tile a batch of images with ``merge`` and write the grid as a PNG file.

    Pixel values are min-max rescaled to 0..255 before writing. This mirrors
    the byte scaling that the removed ``scipy.misc.imsave`` applied to float
    input. As a result, images in [0, 1] and generator output in [-1, 1]
    produce the same picture, and no separate range conversion is needed.

    Args:
        images: Batch accepted by ``merge``.
        image_path: Destination file; PNG data is written regardless of the
            extension.

    Returns:
        None.
    """
    image = np.squeeze(merge(images))
    low, high = float(image.min()), float(image.max())
    # A constant image would divide by zero; a span of 1 maps every pixel
    # to 0, as scipy's bytescale did.
    span = (high - low) or 1.0
    pixels = (((image - low) * (255.0 / span)).clip(0, 255) + 0.5).astype(np.uint8)
    _write_png(image_path, pixels)


def _write_png(path, pixels):
    """Write a uint8 array as an 8-bit PNG: (H, W) gray, (H, W, 3) RGB or (H, W, 4) RGBA."""
    import struct
    import zlib

    pixels = np.ascontiguousarray(pixels, dtype=np.uint8)
    if pixels.ndim == 2:
        color_type = 0
    elif pixels.ndim == 3 and pixels.shape[2] in (3, 4):
        color_type = 2 if pixels.shape[2] == 3 else 6
    else:
        raise ValueError("cannot write array of shape {} as PNG".format(pixels.shape))
    height, width = pixels.shape[0], pixels.shape[1]

    def chunk(tag, data):
        # PNG chunk layout: big-endian length, type, data, CRC32(type + data).
        return (struct.pack(">I", len(data)) + tag + data
                + struct.pack(">I", zlib.crc32(tag + data) & 0xFFFFFFFF))

    # Each scanline is prefixed with filter type 0 (no filtering).
    raw = b"".join(b"\x00" + row.tobytes() for row in pixels)
    header = struct.pack(">IIBBBBB", width, height, 8, color_type, 0, 0, 0)
    with open(path, "wb") as f:
        f.write(b"\x89PNG\r\n\x1a\n")
        f.write(chunk(b"IHDR", header))
        f.write(chunk(b"IDAT", zlib.compress(raw)))
        f.write(chunk(b"IEND", b""))
# Visual check that the loaded images and labels correspond: save the first
# five 64-image batches as grid images and print each batch's class labels.
import os

data_X, data_y = load_mnist()
result_dir = "mnist"
model_name = "image-2-image"
batch_size = 64
num_batches = 5
# Largest square grid that fits one batch (8 x 8 for 64 images).
manifold_h = int(np.floor(np.sqrt(batch_size)))
manifold_w = int(np.floor(np.sqrt(batch_size)))

# Writing the images fails if the output directory does not exist yet.
os.makedirs(result_dir, exist_ok=True)
for idx in range(num_batches):
    start = idx * batch_size
    batch_images = data_X[start:start + batch_size]
    batch_images_y = data_y[start:start + batch_size]
    image_path = os.path.join(
        result_dir, "{}_real_image_{:04d}.png".format(model_name, idx))
    save_images(batch_images[:manifold_h * manifold_w, :, :, :], image_path)
    # Labels are one-hot; print class indices so they can be compared with
    # the saved grid row by row.
    print("batch_images_y的数值是:", np.argmax(batch_images_y, axis=1))

结果展示:

加载fashion-mnist

猜你喜欢

转载自blog.csdn.net/qq_41776781/article/details/86767374