基于mnist数据库的CNN卷积神经网络手写数字识别MATLAB仿真

目录

一、理论基础

二、核心程序

三、仿真结论


一、理论基础


      手写数字识别是计算机视觉领域的一个重要问题,它在很多应用中都有广泛的应用,如自动化识别、自然语言处理、人机交互等。基于MNIST数据库的CNN卷积神经网络手写数字识别是其中一种常用的算法。本文将从数学模型和实现步骤两个方面,详细介绍基于MNIST数据库的CNN卷积神经网络手写数字识别算法。

数学模型

       卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,主要用于处理具有网格状结构的数据,如图像和声音。卷积神经网络具有多层结构,其中每个层都包含许多卷积核(Convolutional Kernel),这些卷积核可以提取输入数据的不同特征。

       假设我们有一个输入图像$x$,大小为$W\times H\times D$,其中$W$是图像的宽度,$H$是图像的高度,$D$是图像的深度,也称为通道数。第一个卷积层将对输入图像进行卷积操作,使用卷积核对输入图像进行卷积运算,得到一个输出特征图$z$。输出特征图的大小为$(W-K+2P)/S+1\times (H-K+2P)/S+1\times F$,其中$K$是卷积核的大小,$P$是填充大小,$S$是步幅,$F$是卷积核的数量。输出特征图中的每个元素$z_{i,j,k}$表示卷积核在输入图像上的一个局部区域内提取的特征。

      对于第一个卷积层的输出特征图$z$,我们可以对其进行池化操作,将每个卷积核提取的特征进行压缩,得到一个更小的特征图。常用的池化操作有最大池化(Max Pooling)和平均池化(Average Pooling)。最大池化操作是在输入特征图上进行滑动窗口操作,将窗口内的最大值作为输出值。平均池化操作则是将窗口内的所有值取平均值作为输出值。池化操作可以减小特征图的大小,同时提高特征的鲁棒性和不变性。

       卷积神经网络还包含多个全连接层(Fully Connected Layer),用于对特征进行分类。全连接层将特征图展开成一维向量,然后进行线性变换和非线性激活操作,得到输出结果。通常使用Softmax函数将输出值映射到概率分布上,得到每个类别的概率值。

实现步骤

数据准备
       MNIST(Mixed National Institute of Standards and Technology database)是一个常用的手写数字识别数据库,包含60000个训练样本和10000个测试样本。我们可以使用Python的Keras库中的mnist.load_data()函数来加载MNIST数据库。加载后的数据集包含两个元组,分别为训练集和测试集,每个元组包含一个图像数组和一个标签数组。

模型构建
       我们可以使用Keras库来构建CNN模型。在Keras中,我们可以使用Sequential模型来构建多层神经网络。首先,我们可以添加一个卷积层(Convolutional Layer),包含32个卷积核,每个卷积核的大小为3×3,步幅为1,激活函数为ReLU。然后,我们可以添加一个最大池化层(Max Pooling Layer),池化窗口大小为2×2,步幅为2。接下来,我们再添加一个卷积层,包含64个卷积核,每个卷积核的大小为3×3,步幅为1,激活函数为ReLU。再添加一个最大池化层,池化窗口大小为2×2,步幅为2。最后,我们添加两个全连接层(Dense Layer),其中第一个全连接层包含128个神经元,激活函数为ReLU,第二个全连接层包含10个神经元,激活函数为Softmax。模型的总体结构如下所示:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0
_________________________________________________________________
flatten (Flatten)            (None, 1600)              0
_________________________________________________________________
dense (Dense)                (None, 128)               204928
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290
=================================================================
Total params: 224,034
Trainable params: 224,034
Non-trainable params: 0
_________________________________________________________________

模型训练
        我们可以使用Keras库中的compile()函数来编译模型,指定损失函数、优化器和评估指标。对于手写数字识别问题,我们可以使用交叉熵损失函数(categorical crossentropy),使用Adam优化器(Adam optimizer),评估指标可以选择准确率(accuracy)。然后,我们可以使用fit()函数来训练模型,指定训练集和测试集、批量大小、训练轮数等参数。

模型评估
        训练完成后,我们可以使用evaluate()函数来评估模型在测试集上的性能,计算损失值和准确率。我们也可以使用predict()函数来对新的手写数字图像进行分类预测。

二、核心程序

function net = cnntrain(net, x, y, opts)%net为网络,x是数据,y为训练目标,opts为优化参数
    %m为图片样本的数量
    m = size(x, 3);
    %batchsize为批训练时,一批所含图片样本数
    numbatches = m / opts.batchsize;%分批训练,得到训练批数
    if rem(numbatches, 1) ~= 0
        error('numbatches not integer');
    end
    net.rL = [];%rL是干嘛的?
    for i = 1 : opts.numepochs%训练迭代
        %显示训练到第几个epoch,一共多少个epoch
        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)]);
        tic;%每个epoch的开始计时
        %Matlab自带函数randperm(n)产生1到n的整数的无重复的随机排列,利用它就可以得到无重复的随机数。
        %生成m(图片数量)1~n整数的随机无重复排列,用于打乱训练顺序;
        kk = randperm(m);
        for l = 1 : numbatches%训练每个batch
            %得到训练信号,一个样本是一层 x(:,:,sampleOrder),每次训练,取50个样本;
            batch_x = x(:, :, kk( (l - 1) * opts.batchsize + 1 : l * opts.batchsize) );
            %教师信号,一个样本是 一列
            batch_y = y(:,    kk( (l - 1) * opts.batchsize + 1 : l * opts.batchsize) );
            %创建前馈网络
            net = cnnff(net, batch_x);
            %bp训练
            net = cnnbp(net, batch_y);
            %运用优化网格?
            net = cnnapplygrads(net, opts);
            if isempty(net.rL)
                net.rL(1) = net.L;
            end
            net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L;
        end
        toc;%每个epoch的结束计时
    end
    
end
up2142

三、仿真结论

猜你喜欢

转载自blog.csdn.net/ccsss22/article/details/131361845