版权声明:用心写好你的每一篇文章,我的个人博客已上线:http://thinkgamer.cn https://blog.csdn.net/Gamer_gyt/article/details/80039242
打开微信扫一扫,关注微信公众号【数据与算法联盟】
转载请注明出处: http://blog.csdn.net/gamer_gyt
博主微博: http://weibo.com/234654758
Github: https://github.com/thinkgamer
MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.
数据集 | 目的 |
---|---|
data_sets.train | 55000 组 图片和标签, 用于训练。 |
data_sets.validation | 5000 组 图片和标签, 用于迭代验证训练的准确性。 |
data_sets.test | 10000 组 图片和标签, 用于最终测试训练的准确性。 |
数据集简介
MNIST数据集加载有两种办法,第一是直接从网上下载,第二是下载到本地进行load(跟第一种类似,只不过是事先下载好,从本地进行加载)。从网上下载到本地方式如下:
# 加载mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
print("load finish")
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
print(type(mnist))
输出为:
load finish
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
<class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>
print("MNIST 训练集数据条数:" ,mnist.train.num_examples)
print("MNIST 测试集数据条数:" ,mnist.test.num_examples)
train_img = mnist.train.images
train_label = mnist.train.labels
print("训练集类型:",type(train_img))
print("训练集维度:",train_img.shape)
test_img = mnist.test.images
test_label = mnist.test.labels
print("测试集类型:",type(test_img))
print("测试集维度:",test_img.shape)
输出为:
MNIST 训练集数据条数: 55000
MNIST 测试集数据条数: 10000
训练集类型: <class 'numpy.ndarray'>
训练集维度: (55000, 784)
测试集类型: <class 'numpy.ndarray'>
测试集维度: (10000, 784)
打开当前运行代码的目录,我们会发现一个MNIST_data的文件夹,里边包含的文件如下:
文件 | 内容 |
---|---|
train-images-idx3-ubyte.gz | 训练集图片 - 55000 张 训练图片, 5000 张 验证图片 |
train-labels-idx1-ubyte.gz | 训练集图片对应的数字标签 |
t10k-images-idx3-ubyte.gz | 测试集图片 - 10000 张 图片 |
t10k-labels-idx1-ubyte.gz | 测试集图片对应的数字标签 |
使用next_batch函数加载指定条数的数据集
# 关于next_batch函数
batchSize = 100
batch_x,batch_y = mnist.train.next_batch(batch_size=batchSize)
print(batch_x.shape)
print(batch_y.shape)
输出为:
(100, 784)
(100, 10)
打开微信扫一扫,加入数据与算法交流大群