keras模型处理mnist

import numpy as np

np.random.seed(1337)
from tensorflow.examples.tutorials.mnist import input_data
# mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.optimizers import RMSprop

'''
models.Sequential,用来一层一层一层的去建立神经层;
layers.Dense 意思是这个神经层是全连接层。
layers.Activation 激励函数。
optimizers.RMSprop 优化器采用 RMSprop,加速神经网络训练方法。
'''

#X_train, y_train, X_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# 下载数据集

# data pre-processing
print(X_train.shape[0])
X_train = X_train.reshape(X_train.shape[0], -1) / 255.  # normalize
X_test = X_test.reshape(X_test.shape[0], -1) / 255.  # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)  # one_hot 编码

model = Sequential([
    Dense(32, input_dim=784),
    Activation('relu'),
    Dense(10),
    Activation('softmax'),
])
'''
第一段就是加入 Dense 神经层。32 是输出的维度,784 是输入的维度。 
第一层传出的数据有 32 个 feature,传给激励单元,激励函数用到的是 relu 函数。 
经过激励函数之后,就变成了非线性的数据。 
然后再把这个数据传给下一个神经层,这个 Dense 我们定义它有 10 个输出的 feature。
同样的,此处不需要再定义输入的维度,因为它接收的是上一层的输出。 
接下来再输入给下面的 softmax 函数,用来分类。
'''

rmsprop = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)
'''
接下来用 RMSprop 作为优化器,它的参数包括学习率等,可以通过修改这些参数来看一下模型的效果。
'''
# 激活模型
model.compile(optimizer=rmsprop,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

print('Training ------------')
# Another way to train the model
# 训练
model.fit(X_train, y_train, epochs=3, batch_size=32)
'''
 fit 函数,把训练集的 x 和 y 传入之后,epochs 表示把整个数据训练多少次,batch_size 每批处理32个。
'''

# 测试模型
print('\nTesting ------------')
# Evaluate the model with the metrics we defined earlier
loss, accuracy = model.evaluate(X_test, y_test)

print('test loss: ', loss)
print('test accuracy: ', accuracy)

猜你喜欢

转载自blog.csdn.net/sunshunli/article/details/81358120