GoogLeNet已经接受了超过一百万张图像的培训,可以将图像分为1000个对象类别(如键盘,咖啡杯,铅笔和许多动物)。 该网络已经为各种图像学习了丰富的特征表示。 网络将图像作为输入,然后输出图像中对象的标签以及每个对象类别的概率。迁移学习通常用于深度学习应用程序,可以使用预训练网络并将其作为学习新任务的起点。
下载GoogLeNet toolbox
在命令窗口输入 “googlenet”进入页面下载,如已经安装会出现以下字样:
ans =
DAGNetwork - 属性:
Layers: [144×1 nnet.cnn.layer.Layer]
Connections: [170×2 table]
如何想要对GoogLeNet有深入的了解,可以在电脑的matla安装路径下找到对应的toolbox的文件夹。例如:matlab2018\toolbox\nnet\cnn
利用迁移学习对GoogLeNet进行更改,下面是mathwork官网上的例子,可以进行参考。
%% 加载数据
% 解压缩并将新图像作为图像数据存储加载。这个非常小的数据集只包含75个图像。 将数据划分为训练和验证数据集。使用70%的图像进行训练,30%进行验证。
unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
%% 加载预训练网络
net = googlenet;
%% 从训练有素的网络中提取图层图并绘制图层图。
lgraph = layerGraph(net);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)
net.Layers(1)
inputSize = net.Layers(1).InputSize;
lgraph = removeLayers(lgraph, {'loss3-classifier','prob','output'});
%% 替换最终图层
numClasses = numel(categories(imdsTrain.Labels));
newLayers = [
fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10)
softmaxLayer('Name','softmax')
classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);
%%
lgraph = connectLayers(lgraph,'pool5-drop_7x7_s1','fc');
figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
plot(lgraph)
ylim([0,10])
%% 冻结初始图层
layers = lgraph.Layers;
connections = lgraph.Connections;
% edit(fullfile(matlabroot,'examples','nnet','main','freezeWeights.m'))
% edit(fullfile(matlabroot,'examples','nnet','main','createLgraphUsingConnections.m'))
layers(1:110) = freezeWeights(layers(1:110));
lgraph = createLgraphUsingConnections(layers,connections);
%% 训练网络
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'ValidationPatience',Inf, ...
'Verbose',false ,...
'Plots','training-progress');
net = trainNetwork(augimdsTrain,lgraph,options);
%% 对验证图像进行分类
[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels);