Keras——用Keras搭建RNN分类循环神经网络

1.前言

这次我们用循环神经网络(RNN, Recurrent Neural Networks)进行分类(classification),采用MNIST数据集,主要用到SimpleRNN层。

2.用Keras搭建RNN循环神经网络

2.1.导入必要模块

import numpy as np
from keras.datasets import mnist    #手写体数据集模块
from keras.utils import np_utils
from keras.models import Sequential   #构建网络必需模块  
from keras.layers import SimpleRNN, Activation, Dense    #RNN、激活函数、全连接层模块
from keras.optimizers import Adam   #优化器模块
np.random.seed(42)   #随机数种子

2.2.超参数设置

MNIST里面的图像分辨率是28×28,为了使用RNN,我们将图像理解为序列化数据。 每一行作为一个输入单元,所以输入数据大小INPUT_SIZE = 28; 先是第1行输入,再是第2行,第3行,第4行,…,第28行输入, 这就是一张图片也就是一个序列,所以步长TIME_STEPS = 28。

TIME_STEPS = 28     #可理解为每张图片的行数
INPUT_SIZE = 28     #可理解为每张图片的列数
BATCH_SIZE = 50     #批量大小
BATCH_INDEX = 0   
OUTPUT_SIZE = 10    #输出维度大小
CELL_SIZE = 50      #经过RNN后的输出大小
LR = 0.001

2.3.数据预处理

训练数据要进行归一化处理,因为原始数据是8bit灰度图像所以需要除以255。

(X_train, y_train),(X_test, y_test) = mnist.load_data()     #拆分训练集与测试集

X_train = X_train.reshape(-1,28,28)/255    #满足输入RNN为(-1,28,28)
X_test = X_test.reshape(-1,28,28)/255
y_train = np_utils.to_categorical(y_train,num_classes=10)    #将类别向量转换为二进制(只有0和1)的矩阵类型表示
y_test = np_utils.to_categorical(y_test,num_classes=10)

2.4.搭建模型

首先添加RNN层,输入为训练数据,输出数据大小由CELL_SIZE定义。

然后添加输出层,激励函数选择softmax

model = Sequential()
model.add(SimpleRNN(
    batch_input_shape = (None,TIME_STEPS,INPUT_SIZE),
    output_dim = CELL_SIZE,
    unroll = True
))

model.add(Dense(OUTPUT_SIZE))
model.add(Activation('softmax'))

2.5.激活模型

设置优化方法,loss函数和metrics方法之后就可以开始训练了。 每次训练的时候并不是取所有的数据,只是取BATCH_SIZE个序列,或者称为BATCH_SIZE张图片,这样可以大大降低运算时间,提高训练效率。

adam = Adam(LR)
model.compile(optimizer=adam,
              loss = 'categorical_crossentropy',
              metrics=['accuracy'])

2.6.训练+测试

for step in range(10001):
  X_batch = X_train[BATCH_INDEX:BATCH_INDEX+BATCH_SIZE,:,:]
  y_batch = y_train[BATCH_INDEX:BATCH_INDEX+BATCH_SIZE,:]
  cost = model.train_on_batch(X_batch,y_batch)
  BATCH_INDEX += BATCH_SIZE
  BATCH_INDEX = 0 if BATCH_INDEX >= X_train.shape[0] else BATCH_INDEX
  if step % 500 == 0:
    cost, accuracy = model.evaluate(X_test, y_test, batch_size=y_test.shape[0],verbose=False)
    print('test cost:',cost,'test accuracy:',accuracy)

在这里插入图片描述

发布了233 篇原创文章 · 获赞 645 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/weixin_37763870/article/details/105601604