Training and recognition of the MNIST handwritten digit database based on a CNN network

Table of contents

1. Theoretical basis

2. Core program

3. Simulation conclusion


1. Theoretical basis

       Handwritten digit recognition is an important application of machine learning, with uses in automatic data entry, intelligent detection, and other fields. The MNIST handwritten digit database is a classic dataset in this field: it contains a collection of handwritten digit images with corresponding labels and is one of the standard datasets for developing and evaluating handwritten digit recognition algorithms. This article introduces the implementation method and steps for training and recognition on the MNIST handwritten digit database using a CNN deep learning network.

1. MNIST handwritten digit database

      The MNIST handwritten digit database is a commonly used handwritten digit recognition dataset containing 60,000 training samples and 10,000 test samples. Each sample is a 28x28-pixel grayscale image representing a digit between 0 and 9. The training set is used to train the model, and the test set is used to evaluate its accuracy.
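
As a quick illustration (not part of the core program below, but using the same loadMNISTImages/loadMNISTLabels helper functions), each sample can be reshaped from a 784-element column vector back into a 28x28 image for display:

images = loadMNISTImages('mnist/train-images-idx3-ubyte');  % 784 x 60000 matrix, one image per column
labels = loadMNISTLabels('mnist/train-labels-idx1-ubyte');  % 60000 x 1 vector of digit labels
imshow(reshape(images(:, 1), 28, 28));                      % display the first training image
title(sprintf('Label: %d', labels(1)));                     % its corresponding digit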

2. CNN deep learning network

       A CNN (Convolutional Neural Network) is a deep learning network that is especially suitable for image classification and recognition tasks. Unlike traditional neural networks, a CNN can automatically extract features from the data, enabling efficient classification and recognition of images. The core components of a CNN are convolutional layers, pooling layers, and fully connected layers.

Convolutional layer
       The convolutional layer is the core component of a CNN; it automatically extracts features from the input image. In a convolutional layer, a set of learnable convolution kernels is convolved with the input image to produce a set of feature maps. A convolution kernel is a small matrix, trained by the backpropagation algorithm, that extracts a particular feature of the input image.
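
A minimal MATLAB sketch (illustrative only, with a hand-chosen kernel rather than a learned one) of how a single convolution produces a feature map:

img    = rand(28, 28);                     % stand-in for one 28x28 MNIST image
kernel = [1 0 -1; 2 0 -2; 1 0 -1];         % example 3x3 kernel; in a CNN these weights are learned
featureMap = conv2(img, kernel, 'valid');  % 26x26 feature map
activated  = max(featureMap, 0);           % ReLU nonlinearity applied to the feature map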

Pooling layer
      The pooling layer is another core component of a CNN; it reduces the dimensionality of the feature maps. A pooling layer typically applies max pooling or average pooling to compress each region of a feature map, thereby reducing the number of parameters and the amount of computation in the network.
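
A minimal sketch (illustrative only) of 2x2 max pooling with stride 2, which halves each spatial dimension of a feature map:

fm = rand(26, 26);                         % a feature map from the previous layer
pooled = zeros(13, 13);                    % pooled output is half the size in each dimension
for r = 1:13
    for c = 1:13
        block = fm(2*r-1:2*r, 2*c-1:2*c);  % 2x2 region of the feature map
        pooled(r, c) = max(block(:));      % keep only the largest activation in the region
    end
end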

Fully connected layer
      The fully connected layer is the last layer of a CNN and is used to classify the extracted features. In the final fully connected layer, the softmax function maps the class scores to probability values between 0 and 1, yielding the classification result for the input image.
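
A minimal sketch (illustrative only) of how the softmax function turns the ten class scores from the last fully connected layer into probabilities:

scores = randn(10, 1);                     % one raw score per digit class 0-9
scores = scores - max(scores);             % subtract the maximum for numerical stability
probs  = exp(scores) ./ sum(exp(scores));  % probabilities between 0 and 1 that sum to 1
[~, predictedClass] = max(probs);          % index of the most likely digit (index 1 corresponds to digit 0)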

3. Implementation of MNIST handwritten digit recognition

       The implementation of MNIST handwritten digit recognition with a CNN deep learning network consists of the following steps: dataset loading, data preprocessing, model building, model training, and model testing. Each of these steps appears in the core program below. The model described here contains two convolutional layers, two pooling layers, and two fully connected layers. The convolutional layers and the first fully connected layer use the ReLU activation function to strengthen the model's nonlinear expressive power, and the last fully connected layer uses a softmax function to map features to probability values.
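
As an illustrative sketch only (it assumes MATLAB's Deep Learning Toolbox and is not the core program of Section 2, which instead extracts features with a sparse autoencoder and classifies them with softmax), the architecture described above could be written as:

layers = [
    imageInputLayer([28 28 1])                    % 28x28 grayscale input
    convolution2dLayer(3, 16, 'Padding', 'same')  % first convolutional layer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)             % first pooling layer
    convolution2dLayer(3, 32, 'Padding', 'same')  % second convolutional layer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)             % second pooling layer
    fullyConnectedLayer(128)                      % first fully connected layer
    reluLayer
    fullyConnectedLayer(10)                       % second fully connected layer: one output per digit
    softmaxLayer                                  % map outputs to probabilities
    classificationLayer];
options = trainingOptions('sgdm', 'MaxEpochs', 5);
% net = trainNetwork(XTrain, YTrain, layers, options);  % XTrain: 28x28x1xN array, YTrain: categorical labels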

4. Summary

       This article has introduced the implementation method and steps for MNIST handwritten digit recognition based on a CNN deep learning network. By walking through dataset loading, preprocessing, model building, model training, and model testing, it helps readers understand the basic principles and implementation of deep learning networks, as well as how to apply them to handwritten digit recognition tasks.
 

2. Core program

clc;
clear;
close all;
warning off;
addpath(genpath(pwd));
rng('default')


inputSize  = 28 * 28;
numLabels  = 5;
hiddenSize = 200;
sparsityParam = 0.1; % desired average activation of the hidden units
                     % (denoted by the Greek letter rho, which looks like a lowercase "p",
                     %  in the lecture notes)
lambda = 3e-3;       % weight decay parameter
beta = 3;            % weight of the sparsity penalty term

%% ======================================================================
%  STEP 1: Load data from the MNIST database
 

% Load MNIST database files
mnistData   = loadMNISTImages('mnist/train-images-idx3-ubyte');
mnistLabels = loadMNISTLabels('mnist/train-labels-idx1-ubyte');

% Set Unlabeled Set (All Images)

% Simulate a Labeled and Unlabeled set
labeledSet   = find(mnistLabels >= 0 & mnistLabels <= 4);
unlabeledSet = find(mnistLabels >= 5);

numTrain = round(numel(labeledSet)/2);
trainSet = labeledSet(1:numTrain);
testSet  = labeledSet(numTrain+1:end);

unlabeledData = mnistData(:, unlabeledSet);

trainData   = mnistData(:, trainSet);
trainLabels = mnistLabels(trainSet)' + 1; % Shift Labels to the Range 1-5

testData   = mnistData(:, testSet);
testLabels = mnistLabels(testSet)' + 1;   % Shift Labels to the Range 1-5

% Output Some Statistics
fprintf('# examples in unlabeled set: %d\n', size(unlabeledData, 2));
fprintf('# examples in supervised training set: %d\n\n', size(trainData, 2));
fprintf('# examples in supervised testing set: %d\n\n', size(testData, 2));

%% ======================================================================
%  STEP 2: Train the sparse autoencoder
 
theta = initializeParameters(hiddenSize, inputSize);

 
addpath minFunc/
autoencoderOptions.Method = 'lbfgs';  % Here, we use L-BFGS to optimize our cost
                                      % function. Generally, for minFunc to work, you
                                      % need a function pointer with two outputs: the
                                      % function value and the gradient. In our problem,
                                      % sparseAutoencoderCost.m satisfies this.
autoencoderOptions.maxIter = 400;	  % Maximum number of iterations of L-BFGS to run 
autoencoderOptions.display = 'on';

if exist('opttheta.mat','file')==2
    
    load('opttheta.mat');

else
[opttheta, cost] = minFunc( @(p) sparseAutoencoderCost(p, ...
                                   inputSize, hiddenSize, ...
                                   lambda, sparsityParam, ...
                                   beta, unlabeledData), ...
                              theta, autoencoderOptions);                        
save('opttheta.mat','opttheta');
end


%% -----------------------------------------------------
                          
% Visualize weights
W1 = reshape(opttheta(1:hiddenSize * inputSize), hiddenSize, inputSize);
display_network(W1');

%%======================================================================
%% STEP 3: Extract Features from the Supervised Dataset
 

trainFeatures = feedForwardAutoencoder(opttheta, hiddenSize, inputSize, ...
                                       trainData);

testFeatures = feedForwardAutoencoder(opttheta, hiddenSize, inputSize, ...
                                       testData);

%%======================================================================
%% STEP 4: Train the softmax classifier
 

softmaxOptions.maxIter = 100;
lambdaSoftmax = 1e-4; % Weight decay parameter for Softmax
trainNumber = size(trainData,2);

% softmaxTrain assumes that the data already includes the intercept (bias) term
softmaxModel = softmaxTrain(hiddenSize+1, numLabels, lambdaSoftmax, [trainFeatures;ones(1,trainNumber)], trainLabels, softmaxOptions);  % learn by features
%softmaxModel = softmaxTrain(inputSize+1, numLabels, lambdaSoftmax, [trainData;ones(1,trainNumber)], trainLabels, softmaxOptions);  % learn by raw data


%% -----------------------------------------------------


%%======================================================================
%% STEP 5: Testing 

%% ----------------- YOUR CODE HERE ----------------------
% Compute Predictions on the test set (testFeatures) using softmaxPredict
% and softmaxModel
testNumber = size(testData,2);

% softmaxPredict assumes that the data already includes the intercept (bias) term
[pred] = softmaxPredict(softmaxModel, [testFeatures;ones(1,testNumber)]);  % predict by test features
%[pred] = softmaxPredict(softmaxModel, [testData;ones(1,testNumber)]);  % predict by test raw data


%% -----------------------------------------------------

% Classification Score
fprintf('Test Accuracy: %f%%\n', 100*mean(pred(:) == testLabels(:)));

3. Simulation conclusion

 
