卷积神经网络和LeNet-5

卷积神经网络

卷积神经网络(CNN)是一种深度学习算法,其仿照生物视觉机制来提取特征,卷积神经网络的隐层包含卷积层池化层全连接层

卷积层

为了减少参数以防过拟合,卷积层中的神经元被分组,每组神经元共享权重并仅连接上层神经元中的一部分,也可以看成一个神经元以“扫描”的方式连接上层神经元(如下图),但这种扫描并不代表卷积层神经元与上层不同神经元的连接存在时间上的先后顺序。每组输出“图像”尺寸都与输入尺寸相同,输入和输出的“图像”数量往往不同,输出的“图像”数量取决于组数。

上述分组即卷积核,卷积核中的每一个神经元权重矩阵都一样。

下图中输入为 X + \textbf{X}^+ X+而非 X \textbf{X} X的原因是:为了使输入和输出尺寸一致对输入进行填充(参考下图中 Y \textbf{Y} Y变为 Y + \textbf{Y}^+ Y+的过程,是在输出矩阵四周增加一些不影响计算正确性的值,一个值的下标可能变化但本身与其它值的相对位置不变,图中 Y \textbf{Y} Y变为 Y + \textbf{Y}^+ Y+的过程是对图中所示卷积层的下一层的输入进行填充,不一定非要这么做)。
诶哟图丢了

池化层

池化层位于卷积层之后,对卷积层提取的特征进行选择,即从输入的一部分值中选取一个值作为对应输出。池化层的输出尺寸小于输入尺寸,但池化层的输出的“图像”数目与输入“图像”数目相同。可以理解为池化层负责以特有的方式对“图像”进行缩小,缩小后保留“图像”特征,这些特征的位置的略微差异并不影响其输出。

逆传播时,池化层神经元仅对其已经选择的卷积层神经元传递误差,即池化层神经元仅对每一部分卷积层神经元中输出值最大的一个(组)神经元传递误差,如下图所示。哪个上层神经元的输出通过池化层传递到了下一层,逆传播时梯度就传递到哪个上层神经元。

池化层神经元本身没有可训练参数,也就是不学习。
诶哟图丢了

全连接层

每一个神经元都与上层的所有神经元连接,普通神经网络的隐层就是全连接层,前几层输出的特征“图像”在被全连接层处理前需要处理为线性结构。
诶哟图丢了

LeNet-5

LeNet-5是最早的卷积神经网络之一,用作快速识别手写数字。入门的不错的选择。

from tensorflow.keras import * # tensorflow2.3.0、python3.8
from tensorflow.keras.layers import *
# LeNet结构
layers = [Conv2D(filters=20, kernel_size=(5, 5), strides=1, activation='relu', padding='same', name="Conv_1"),
          AvgPool2D(data_format='channels_last', strides=2,
                    pool_size=(2, 2), name="Pool_1"),
          Conv2D(filters=50, kernel_size=(5, 5), strides=1,
                 activation='relu', padding='same', name="Conv_2"),
          AvgPool2D(data_format='channels_last', strides=2,
                    pool_size=(2, 2), name="Pool_2"),
          layers.Flatten(),
          Dense(units=500, activation='relu', name="Dense_1"),
          Dense(units=10, activation='softmax', name="Dense_2")]
model = Sequential(layers) # 现在model对象已经是LetNet了

下图中的每一片小图片都是上面代码对应层的实际输出。最后一个池化层输出后已经看不出是数字了。
诶哟图丢了
将一幅.png图片处理成神经网络的输入值:

img = tf.io.gfile.GFile('./example.png', 'rb').read()
x = tf.image.decode_png(img, channels=1)
x = tf.cast(x, tf.float32)/255. 
# x就是一个输入值,将x加入列表后把列表给model.fit()当参数,model是在上面一段代码中声明的变量。

也可以使用tensorflow的数据集:

from tensorflow.examples.tutorials.mnist import *

猜你喜欢

转载自blog.csdn.net/dscn15848078969/article/details/115298736
今日推荐