matlab 畳み込みニューラル ネットワークで発生する問題 - gpu と cpu の結果が異なり、最終的な予測結果にも問題がある

公式ヘルプ ドキュメント プログラムを使用して、手書きの数字セット テストで画像を認識します。

コードは次のとおりです。

clc;clear;close all;
digitDatasetPath = fullfile('C:\Users\86151\Desktop\mnist_data_jpg\新建文件夹');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i);
    imshow(imds.Files{
    
    perm(i)});
end
% % splitEachLabel 将 digitData 中的图像文件拆分为两个新的数据存储,imdsTrain 和 imdsTest。
numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');

% % 定义卷积神经网络架构。
layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
% % 将选项设置为具有动量的随机梯度下降的默认设置。 将最大 epoch 数设置为 20,并以 0.0001 的初始学习率开始训练。
options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');
% % 训练网络
net = trainNetwork(imdsTrain,layers,options);
save('CNNshuzi','net');%保存训练好的神经网络
% % 在未用于训练网络的测试集上运行经过训练的网络,并预测图像标签(数字)。
YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;
% % 计算精度。 准确率是测试数据中与分类匹配的真实标签数量与测试数据中图像数量的比值。
accuracy = sum(YPred == YTest)/numel(YTest)


これはそれぞれ GPU と CPU でトレーニングするプロセスです

ここに画像の説明を挿入

ここに画像の説明を挿入

最終的な予測精度もあり、精度は毎回 10% を超え、YPred のラベルはすべて 1 であるため、トレーニング中の検証の正解率は非常に高くなります。

ここに画像の説明を挿入

最後に、基本的に認識できるテストプログラムを作成しました



clc;clear
%%Load the train model
load('CNNshuzi','net');
%%See details of the architecture
net.Layers
%%Read the image to classify
[file,path]=uigetfile('*');
image=fullfile(path,file);
I=imresize(imread(image),[28,28]);
file


tic
%Adjust size of the image
sz=net.Layers(1).InputSize;
%I=I(1:sz(1),1:sz(2),1:sz(3));

%Classify the image 
label = classify(net,I)
%Show the image and the classification results
figure('Name','识别结果','NumberTitle','off');
imshow(I);

title(['\bf',label]),xlabel(['\bf',label]);


ここに画像の説明を挿入

2021 年 12 月 22 日は
GPU バージョンの問題です。matlab2021 を再ダウンロードした後、CPU トレーニングに問題はありません。
ここに画像の説明を挿入

おすすめ

転載: blog.csdn.net/qq_43666228/article/details/122074432