Keras practice project: MNIST handwritten digit recognition

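For work I needed to run a deep learning program, but not on the Ubuntu or Windows 10 machines I had used before, which made things painful...

This project requires setting up TensorFlow and Keras.

For setting up the environment, you can refer to: https://blog.csdn.net/zlase/article/details/78572041

Once the environment was set up, I hit an error: the handwritten-digit dataset could not be loaded...

I followed this blog post (more precisely, this article is a repost of it) and got it working: https://www.jianshu.com/p/b5e3053a2716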


I. Introduction to Keras

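Keras is a high-level neural network API written in pure Python. It uses TensorFlow as its computation backend by default and is well suited to quickly building a prototype of a deep learning project.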

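Keras currently supports Python 2.7 to 3.6. In practice, though, I found that the TensorFlow (TF) version matters too, and this is a real pitfall. When I wrote this demo, the latest TF release was 1.4. With that version, importing Keras after installing it may keep raising errors. Downgrading TF to 1.3 fixed the incompatibility for me.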

II. MNIST handwritten digit recognition

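MNIST handwritten digit recognition can be thought of as the "Hello World" of deep learning. MNIST is a dataset of handwritten digits (28×28 grayscale images of the digits 0 to 9). It has 60,000 training samples and 10,000 test samples.
More details are available on the official site: http://yann.lecun.com/exdb/mnist/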

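This article covers:

  • how to load the data
  • building a neural network classifier
  • an explanation of some of the method parameters used in the Keras implementation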

III. Classifying digits with a neural network

1. Loading the data

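In Keras, the data is loaded with mnist.load_data(). Unfortunately, calling this method often produces the result below. This is mostly related to the network environment. In this case the download of mnist.npz fails because SSL certificate verification fails (CERTIFICATE_VERIFY_FAILED).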

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 1318, in do_open
    encode_chunked=req.has_header('Transfer-encoding'))
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1239, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1285, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1234, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1026, in _send_output
    self.send(msg)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 964, in send
    self.connect()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1400, in connect
    server_hostname=server_hostname)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/ssl.py", line 401, in wrap_socket
    _context=self, _session=session)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/ssl.py", line 808, in __init__
    self.do_handshake()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/ssl.py", line 1061, in do_handshake
    self._sslobj.do_handshake()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/ssl.py", line 683, in do_handshake
    self._sslobj.do_handshake()
ssl.SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:749)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/utils/data_utils.py", line 221, in get_file
    urlretrieve(origin, fpath, dl_progress)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 248, in urlretrieve
    with contextlib.closing(urlopen(url, data)) as fp:
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 223, in urlopen
    return opener.open(url, data, timeout)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 526, in open
    response = self._open(req, data)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 544, in _open
    '_open', req)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 504, in _call_chain
    result = func(*args)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 1361, in https_open
    context=self._context, check_hostname=self._check_hostname)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 1320, in do_open
    raise URLError(err)
urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:749)>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/xhades/Documents/github/PythonEngineer/mnist/kerasmnist.py", line 69, in <module>
    mnist.load_data()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/datasets/mnist.py", line 17, in load_data
    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/utils/data_utils.py", line 223, in get_file
    raise Exception(error_msg.format(origin, e.errno, e.reason))
Exception: URL fetch failure on https://s3.amazonaws.com/img-datasets/mnist.npz: None -- [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:749)

Process finished with exit code 1

First, let's look at what load_data() looks like in the Keras source:

def load_data(path='mnist.npz'):
    """Loads the MNIST dataset.

    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = get_file(path,
                    origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)

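Opening the source shows that load_data() in turn calls get_file(). I won't go into detail on that method here. Its main job is to check whether the file already exists under the cache path (~/.keras/datasets by default) and to download it if it does not.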
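
To make that behaviour concrete, here is a minimal cache-or-download sketch. The function name get_file_sketch and its parameters are hypothetical. This is not the real keras.utils.get_file, which also handles archive extraction, SHA-256 hashes and a progress bar. Only the idea is the same: reuse a cached file whose MD5 matches file_hash, and otherwise download it with urlretrieve (the call that fails in the traceback above).

```python
import hashlib
import os
import urllib.request


def get_file_sketch(fname, origin, file_hash=None,
                    cache_dir=os.path.join(os.path.expanduser("~"), ".keras", "datasets")):
    """Hypothetical, simplified stand-in for keras.utils.get_file.

    Returns the path of the cached file, downloading it from `origin` only when
    it is missing or its MD5 digest does not match `file_hash`.
    """
    os.makedirs(cache_dir, exist_ok=True)
    fpath = os.path.join(cache_dir, fname)

    def md5_of(path):
        # Hash the file in chunks so large files are not read into memory at once
        h = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(65535), b""):
                h.update(chunk)
        return h.hexdigest()

    # Cache hit: the file exists and either no hash was given or the hash matches
    if os.path.exists(fpath) and (file_hash is None or md5_of(fpath) == file_hash):
        return fpath

    # Cache miss: download (this is the step that raised the SSL error above)
    urllib.request.urlretrieve(origin, fpath)
    if file_hash is not None and md5_of(fpath) != file_hash:
        raise ValueError("Downloaded file hash does not match: " + fpath)
    return fpath
```

Placing a manually downloaded mnist.npz in the cache directory therefore skips the network entirely. The hand-written loader below takes the same approach, but with its own path.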

With that in mind, I downloaded the dataset manually from the data download link and rewrote load_data():

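import numpy as np

# The built-in load_data() failed to download the data every time,
# so the file was downloaded manually and is loaded with this custom function
def load_data(path="MNIST_data/mnist.npz"):
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)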
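
To test the custom loader without the real download, the sketch below writes a tiny synthetic mnist.npz that contains the four keys the function reads (x_train, y_train, x_test, y_test), then loads it back. The arrays are random placeholders, and the uint8 dtype is an assumption about the real file. With the real file the image arrays would have shapes (60000, 28, 28) and (10000, 28, 28). The loader is the same as above, rewritten with a `with` block so the file is always closed.

```python
import os
import tempfile

import numpy as np


def load_data(path="MNIST_data/mnist.npz"):
    # Same logic as the custom loader above; the context manager closes the file
    with np.load(path) as f:
        return (f["x_train"], f["y_train"]), (f["x_test"], f["y_test"])


# Synthetic stand-in for mnist.npz (random data, only the key layout matters)
tmp_path = os.path.join(tempfile.mkdtemp(), "mnist.npz")
rng = np.random.default_rng(0)
np.savez(tmp_path,
         x_train=rng.integers(0, 256, size=(6, 28, 28), dtype=np.uint8),
         y_train=rng.integers(0, 10, size=6, dtype=np.uint8),
         x_test=rng.integers(0, 256, size=(2, 28, 28), dtype=np.uint8),
         y_test=rng.integers(0, 10, size=2, dtype=np.uint8))

(x_train, y_train), (x_test, y_test) = load_data(tmp_path)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
```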

2. Building the network with a Sequential model

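from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.optimizers import SGD

# Build the Sequential model
def train():
    model = Sequential()
    model.add(Dense(500, input_shape=(784,)))  # first hidden layer, 500 units; input is 28*28=784 pixels
    model.add(Activation('tanh'))
    model.add(Dropout(0.3))   # 30% dropout
    model.add(Dense(300))  # second hidden layer, 300 units
    model.add(Activation('tanh'))
    model.add(Dropout(0.3))   # 30% dropout
    model.add(Dense(10))  # output layer, one unit per digit class
    model.add(Activation('softmax'))

    # Compile the model (Keras 2.x argument names: lr, decay)
    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    # metrics=['accuracy'] makes evaluate() return [loss, accuracy] instead of only the loss
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    return model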
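
The SGD line above combines a base learning rate, time-based decay and Nesterov momentum. Below is a NumPy sketch of a single update step. I wrote it from my reading of the Keras 2.x SGD source, so treat it as an approximation. Newer Keras versions rename lr to learning_rate and may handle decay through learning-rate schedules instead. The helper name sgd_step is hypothetical.

```python
import numpy as np


def sgd_step(p, g, m, iteration, lr=0.01, decay=1e-6, momentum=0.9, nesterov=True):
    """One SGD update in the style of Keras 2.x (sketch).

    p: parameters, g: gradient of the loss w.r.t. p, m: momentum buffer,
    iteration: number of updates already performed. Returns (new_p, new_m).
    """
    # Time-based decay: the effective learning rate shrinks as training proceeds
    lr_t = lr / (1.0 + decay * iteration)
    v = momentum * m - lr_t * g              # updated velocity
    if nesterov:
        new_p = p + momentum * v - lr_t * g  # Nesterov "look-ahead" step
    else:
        new_p = p + v                        # classical momentum step
    return new_p, v


# Toy usage: minimise f(w) = w**2 (gradient 2w) starting from w = 5
w, m = np.array(5.0), np.array(0.0)
for t in range(200):
    w, m = sgd_step(w, 2 * w, m, t)
print(float(w))  # ends up very close to 0
```

With decay=1e-6, the learning rate is still about 0.00999 after 10,000 updates, so in this demo the decay term has almost no effect.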

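Keras's Sequential model is a linear stack of layers. It has a single input and a single output and runs straight from start to end: each layer connects only to the next one, with no cross-layer connections. This kind of model compiles quickly and is simple to work with. Keras 0.x also had a Graph model. From Keras 1 onward, the Graph model was superseded by the functional API (keras.models.Model), which handles multiple inputs/outputs and cross-layer connections. Keras 2 therefore offers both the Sequential model and the functional API, not the Sequential model alone.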

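Dense is the ordinary fully connected layer. 500 is the layer's output dimension (its number of units), and input_shape=(784,) gives the input dimension: the number of pixels per image, 28*28 = 784.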
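
As a quick arithmetic check on the size of these Dense layers: each one has an n_in × n_out weight matrix plus n_out biases, and the Activation and Dropout layers add no parameters. The plain-Python sketch below (the helper name dense_params is made up) counts the parameters for 784 → 500 → 300 → 10:

```python
# Layer sizes of the model above: 784 inputs -> 500 -> 300 -> 10 outputs
layer_sizes = [784, 500, 300, 10]


def dense_params(n_in, n_out):
    # weight matrix (n_in x n_out) plus one bias per output unit
    return n_in * n_out + n_out


per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
print(per_layer, sum(per_layer))  # → [392500, 150300, 3010] 545810
```

The total, 545,810, should match the "Total params" line that model.summary() prints for this model.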

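The Dropout layer applies dropout to its input. During training, at each parameter update it randomly sets a fraction rate of the input units to 0 (here rate = 0.3). This helps prevent overfitting. Dropout is not applied at test time.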
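
Keras implements this as "inverted" dropout: during training, the surviving units are scaled by 1/(1 - rate), so nothing needs rescaling at test time. Here is a NumPy sketch of the idea. The function is illustrative and not Keras code.

```python
import numpy as np


def dropout(x, rate, training, rng):
    """Inverted dropout: in training, zero each unit with probability `rate`
    and scale the survivors by 1/(1 - rate); at inference return x unchanged."""
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate   # True with probability 1 - rate
    return x * keep / (1.0 - rate)


rng = np.random.default_rng(0)
x = np.ones((1000, 500))
y = dropout(x, 0.3, training=True, rng=rng)
print((y == 0).mean())  # fraction of dropped units, close to 0.3
print(y.mean())         # close to 1.0 thanks to the rescaling
print(np.array_equal(dropout(x, 0.3, training=False, rng=rng), x))  # True
```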

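The activation function here is tanh, added as a separate Activation layer. Keras layers such as Dense also accept an activation argument. If it is not specified, no activation is applied, i.e. a linear activation a(x) = x.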

Finally, the softmax function turns the outputs of the last layer into a probability for each of the 10 labels.
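
Below is a NumPy sketch of softmax, together with the categorical cross-entropy loss that the model is compiled with. Both functions are illustrative re-implementations, not the Keras ones. The clipping constant eps is an assumption, added to avoid log(0).

```python
import numpy as np


def softmax(z):
    """Row-wise softmax; subtracting the row maximum avoids overflow in exp."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum_k y_k * log(p_k), for one-hot y_true."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))


logits = np.array([[2.0, 1.0, 0.1],
                   [1000.0, 1000.0, 1000.0]])  # second row would overflow a naive exp
probs = softmax(logits)
print(probs.sum(axis=1))     # each row sums to 1
print(probs.argmax(axis=1))  # predicted class per row
```

For a uniform prediction over 3 classes, the loss equals log(3) ≈ 1.0986, whichever class is the true one.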

3. Training and testing accuracy

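def run():
    (x_train, y_train), (x_test, y_test) = load_data()
    # Flatten each 28x28 image into a 784-dimensional vector and scale pixels to [0, 1]
    X_train = x_train.reshape(x_train.shape[0], x_train.shape[1] * x_train.shape[2]).astype('float32') / 255
    X_test = x_test.reshape(x_test.shape[0], x_test.shape[1] * x_test.shape[2]).astype('float32') / 255

    # One-hot encode the labels into shape (N, 10)
    Y_train = (np.arange(10) == y_train[:, None]).astype(int)
    Y_test = (np.arange(10) == y_test[:, None]).astype(int)

    model = train()
    model.fit(X_train, Y_train, batch_size=200, epochs=10, shuffle=True, verbose=1, validation_split=0.3)
    print("Start Test.....\n")
    # evaluate() returns [loss, accuracy] when compiled with metrics=['accuracy'], otherwise only the loss
    scores = np.atleast_1d(model.evaluate(X_test, Y_test, batch_size=200, verbose=1))
    print("The Test Loss: %f" % scores[0])
    if len(scores) > 1:
        print("The Test Accuracy: %f" % scores[1])


if __name__ == "__main__":
    run()

Training is done mainly by calling fit(), and the accuracy on the test set is measured with evaluate().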
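
The data-preparation steps in run() can be checked on their own. The sketch below uses random uint8 arrays as stand-ins for the MNIST images (synthetic data, not the real dataset). It shows the flattening to 784 dimensions, the broadcasting trick behind the one-hot labels, how validation_split=0.3 carves out validation data, and the accuracy that evaluate() reports when the model is compiled with metrics=['accuracy']. The helper accuracy() is hypothetical, not a Keras function.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-ins for the MNIST arrays (shapes and dtypes only, not real data)
x_train = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
y_train = rng.integers(0, 10, size=100, dtype=np.uint8)

# Flatten each 28x28 image into a 784-vector, as in run()
X_train = x_train.reshape(x_train.shape[0], x_train.shape[1] * x_train.shape[2])

# One-hot via broadcasting: comparing shape (10,) with shape (N, 1) gives an (N, 10) boolean matrix
Y_train = (np.arange(10) == y_train[:, None]).astype(int)
# Equivalent formulation: pick rows of a 10x10 identity matrix
assert np.array_equal(Y_train, np.eye(10, dtype=int)[y_train])

# validation_split=0.3: the last 30% of the samples (before shuffling) are held out
split_at = int(len(X_train) * (1 - 0.3))
X_fit, X_val = X_train[:split_at], X_train[split_at:]


def accuracy(y_true_onehot, y_prob):
    # Fraction of samples whose highest-probability class equals the true label
    return float(np.mean(y_true_onehot.argmax(axis=1) == y_prob.argmax(axis=1)))


print(X_train.shape, Y_train.shape, len(X_fit), len(X_val))
```

According to the Keras documentation for fit(), the validation samples are the last ones in the arrays passed in, taken before shuffling. With the 60,000 MNIST training images, validation_split=0.3 should therefore train on the first 42,000 and hold out the last 18,000.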










Reprinted from blog.csdn.net/Zlase/article/details/80668970