How to load the MNIST and Fashion-MNIST datasets

Disclaimer: everything here belongs to the proletariat of the world. Original article: https://blog.csdn.net/qq_41776781/article/details/86767374
The MNIST handwritten-digit dataset is the simplest dataset in artificial intelligence. The download link for the dataset is:

mnist data set download

The Fashion-MNIST dataset is stored in exactly the same format as the MNIST dataset. The download link for the dataset is:

dataset download fashion-mnist

  • Program explanation: the program is divided into two parts.
  • First, load_mnist() loads either the MNIST dataset or the Fashion-MNIST dataset. Both use the same file format, so it is recommended to keep them in different folders; the dataset that gets loaded is selected by the folder specified in data_dir.
  • Second, the for loop saves the loaded images and prints the corresponding labels.
  • Finally, when saving images with save_images, remember to adjust the value range of the images to match the data being saved.
def load_mnist(data_dir="../Dataset/mnist_data/"):
    """Load the MNIST (or Fashion-MNIST) train and test splits as one shuffled set.

    Both datasets ship the same four gzip-compressed IDX files, so the
    dataset is selected purely by the directory passed in, e.g.
    ``"../Dataset/fashion-mnist/"``.

    Args:
        data_dir: Directory containing ``train-images-idx3-ubyte.gz``,
            ``train-labels-idx1-ubyte.gz``, ``t10k-images-idx3-ubyte.gz``
            and ``t10k-labels-idx1-ubyte.gz``.

    Returns:
        A tuple ``(X, y_vec)``. ``X`` is a float array of shape
        ``(N, rows, cols, 1)`` with pixel values scaled to ``[0, 1]``;
        ``y_vec`` is the matching ``(N, 10)`` one-hot label matrix. The
        train and test samples are concatenated and shuffled together
        using NumPy's global random state.

    Raises:
        ValueError: If a file is not an unsigned-byte IDX file, is
            truncated, or the image and label counts disagree.
    """
    import os  # local import: the file-level import block is not visible here

    def read_idx(filename):
        """Return the uint8 payload of a gzipped IDX file, shaped per its header."""
        with gzip.open(os.path.join(data_dir, filename), "rb") as stream:
            magic = stream.read(4)
            # IDX magic number: two zero bytes, a type code (0x08 means
            # unsigned byte) and the number of dimensions.
            if len(magic) != 4 or magic[:2] != b"\x00\x00" or magic[2] != 0x08:
                raise ValueError(
                    "{!r} is not an unsigned-byte IDX file".format(filename))
            # Each dimension size is stored as a big-endian 32-bit integer.
            shape = tuple(int.from_bytes(stream.read(4), "big")
                          for _ in range(magic[3]))
            expected = 1
            for dim in shape:
                expected *= dim
            payload = stream.read(expected)
        values = np.frombuffer(payload, dtype=np.uint8)
        if values.size != expected:
            raise ValueError(
                "{!r} is truncated: expected {} bytes, got {}".format(
                    filename, expected, values.size))
        return values.reshape(shape)

    trX = read_idx('train-images-idx3-ubyte.gz')
    trY = read_idx('train-labels-idx1-ubyte.gz')
    teX = read_idx('t10k-images-idx3-ubyte.gz')
    teY = read_idx('t10k-labels-idx1-ubyte.gz')

    # The builtin float/int stand in for the np.float/np.int aliases that
    # were removed from NumPy; the resulting dtypes are the same.
    X = np.concatenate((trX, teX), axis=0).astype(float)[..., np.newaxis]
    y = np.concatenate((trY, teY), axis=0).astype(int)
    if len(X) != len(y):
        raise ValueError(
            "image/label count mismatch: {} images, {} labels".format(
                len(X), len(y)))

    print("*****************dataX**************", len(X))

    # Shuffle images and labels with one shared permutation so that they
    # stay aligned.
    data_index = np.arange(X.shape[0])
    np.random.shuffle(data_index)
    X = X[data_index, :, :, :]
    y = y[data_index]

    # One-hot encode the labels in a single vectorised assignment.
    y_vec = np.zeros((len(y), 10), dtype=float)
    y_vec[np.arange(len(y)), y] = 1.0

    return X / 255., y_vec
def merge(images, size=(8, 8)):
    """Tile a batch of images into a single grid image, filling row by row.

    Args:
        images: Array (or list of arrays) of shape ``(N, h, w, c)`` where
            ``c`` is 1 or 3; single-channel images are broadcast over the
            three output channels.
        size: ``(rows, cols)`` of the grid. Defaults to an 8x8 grid.

    Returns:
        A float array of shape ``(h * rows, w * cols, 3)``. Grid cells
        without a corresponding image stay zero.

    Raises:
        ValueError: If there are more images than grid cells.
    """
    rows, cols = size
    images = np.asarray(images)
    if len(images) > rows * cols:
        raise ValueError(
            "cannot place {} images on a {}x{} grid".format(
                len(images), rows, cols))

    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * rows, w * cols, 3))
    for idx, image in enumerate(images):
        # Row-major placement: idx -> (grid row, grid column).
        row, col = divmod(idx, cols)
        img[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
    return img
def save_images(images, image_path):
    """Tile a batch of images into one grid and write it to ``image_path``.

    Args:
        images: Batch of images accepted by ``merge``.
        image_path: Destination path handed to ``scipy.misc.imsave``.

    Returns:
        Whatever ``scipy.misc.imsave`` returns.
    """
    # Pixel values are passed to the writer as-is; rescale the batch before
    # calling if it is not already in the range you want to store.
    image = np.squeeze(merge(images))
    # NOTE(review): scipy.misc.imsave was deprecated in SciPy 1.0 and removed
    # in SciPy 1.2. This call needs an older SciPy (with Pillow installed) or
    # must be swapped for an image writer such as imageio.imwrite.
    return scipy.misc.imsave(image_path, image)
import os

# Load the combined, shuffled train + test set.
data_X, data_y = load_mnist()
result_dir = "mnist"
model_name = "image-2-image"

batch_size = 64
# Side lengths of the largest square grid that one batch can fill.
manifold_h = int(np.floor(np.sqrt(batch_size)))
manifold_w = int(np.floor(np.sqrt(batch_size)))

# Make sure the output directory exists before writing images into it.
os.makedirs(result_dir, exist_ok=True)

# Sanity check that the loaded images and labels correspond: save the first
# few batches as image grids and print the labels of each batch.
for idx in range(5):
    batch_images = data_X[idx * batch_size:(idx + 1) * batch_size]
    batch_images_y = data_y[idx * batch_size:(idx + 1) * batch_size]

    save_images(batch_images[:manifold_h * manifold_w, :, :, :],
                './' + result_dir + '/' + model_name + '_real_image_{:04d}.png'.format(
                    idx))
    print("batch_images_y的数值是:", batch_images_y)

The results show:

Load fashion-mnist

Guess you like

Origin blog.csdn.net/qq_41776781/article/details/86767374