MNIST 的全称是 Modified National Institute of Standards and Technology database。
MNIST 是一个手写数字数据集,包含 70000 个灰度图像,这些图像分辨率为 28x28 像素,数字从 0 到 9 不等。它被广泛用于机器学习和深度学习的测试和评估。
MNIST数据集是一个常用的手写数字识别数据集,包含了大约 60000 张训练集图片和 10000 张测试集图片,每张图片都是 28 像素 * 28 像素的灰度图像。这些图像都经过了预处理和标准化,使得每张图像都被表示为一个行向量,其中每个元素的值都在 0 到 1 之间。MNIST数据集中的每个图像都标注有对应的数字,因此该数据集通常用于训练和评估机器学习算法和模型的性能,尤其是对手写数字识别算法和模型的评估。可以从MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges下载
import mnist
#加载Keras的MNIST数据集
from tensorflow.keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 测试集:图像和标签,一一对应。
# 使用MNIST数据集,它是机器学习领域的一个经典数据集数据集。
# 手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)。包含60 000张训练图像和10 000张测试图像。
t = train_images.shape
print(t)
l = len(train_labels)
print(l)
tl = train_labels
print(tl)
print(test_images.shape)
print(len(test_labels))
print(test_labels)
'''
(60000, 28, 28)
60000
[5 0 4 ... 5 6 8]
(10000, 28, 28)
10000
[7 2 1 ... 4 5 6]
'''
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 神经网络架构 layer层包含2个Dense层,密集连接(也叫全连接)神经层
model = keras.Sequential([
layers.Dense(512, activation="relu"), #
layers.Dense(10, activation = "softmax") # 10路 softmax分类层,返回一个10概率值(总和1)组成的数字
])
# 指定编译。
model.compile(
optimizer="rmsprop", #优化器
loss="sparse_categorical_crossentropy",#损失函数
metrics=["accuracy"]#指标 精度
)
# 准备图像数据。 缩放所有值从[0,255]调整为[0,1]
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255
# 拟合模型
model.fit(train_images, train_labels, epochs=5, batch_size=128)
#利用模型进行预测
test_digits = test_images[0:10]
predictions = model.predict(test_digits)
print(predictions[0])
print(predictions[0].argmax()) # 第一个测试数字在索引为7时的概率最大
print(predictions[0][7])
print(test_labels[0])
# 在新数据上评估模型。损失值loss、精度acc
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"test_acc: {test_acc}")
'''
Epoch 1/5
469/469 [==============================] - 3s 4ms/step - loss: 0.2617 - accuracy: 0.9245
Epoch 2/5
469/469 [==============================] - 2s 4ms/step - loss: 0.1069 - accuracy: 0.9692
Epoch 3/5
469/469 [==============================] - 2s 4ms/step - loss: 0.0706 - accuracy: 0.9788
Epoch 4/5
469/469 [==============================] - 2s 4ms/step - loss: 0.0514 - accuracy: 0.9845
Epoch 5/5
469/469 [==============================] - 2s 4ms/step - loss: 0.0381 - accuracy: 0.9887
1/1 [==============================] - 0s 51ms/step
[2.5678100e-08 2.2875359e-09 6.8501649e-06 8.6247064e-06 1.5738954e-11
6.4071308e-08 1.9494181e-12 9.9998415e-01 1.0324158e-08 2.6900142e-07]
7
0.99998415
7
313/313 [==============================] - 1s 2ms/step - loss: 0.0640 - accuracy: 0.9800
test_acc: 0.9800000190734863
'''