GoogLeNet Matlab toolbox 快速入门

GoogLeNet已经接受了超过一百万张图像的培训,可以将图像分为1000个对象类别(如键盘,咖啡杯,铅笔和许多动物)。 该网络已经为各种图像学习了丰富的特征表示。 网络将图像作为输入,然后输出图像中对象的标签以及每个对象类别的概率。迁移学习通常用于深度学习应用程序,可以使用预训练网络并将其作为学习新任务的起点。

下载GoogLeNet toolbox

在命令窗口输入 “googlenet”进入页面下载,如已经安装会出现以下字样:

 

ans = 

  DAGNetwork - 属性:

         Layers: [144×1 nnet.cnn.layer.Layer]
    Connections: [170×2 table]

如何想要对GoogLeNet有深入的了解,可以在电脑的matla安装路径下找到对应的toolbox的文件夹。例如:matlab2018\toolbox\nnet\cnn

利用迁移学习对GoogLeNet进行更改,下面是mathwork官网上的例子,可以进行参考。

%% 加载数据
% 解压缩并将新图像作为图像数据存储加载。这个非常小的数据集只包含75个图像。 将数据划分为训练和验证数据集。使用70%的图像进行训练,30%进行验证。
unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

%% 加载预训练网络
net = googlenet;

%% 从训练有素的网络中提取图层图并绘制图层图。
lgraph = layerGraph(net);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)

net.Layers(1)

inputSize = net.Layers(1).InputSize;

lgraph = removeLayers(lgraph, {'loss3-classifier','prob','output'});

%% 替换最终图层
numClasses = numel(categories(imdsTrain.Labels));
newLayers = [
    fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);

%%
lgraph = connectLayers(lgraph,'pool5-drop_7x7_s1','fc');

figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
plot(lgraph)
ylim([0,10])

%% 冻结初始图层
layers = lgraph.Layers;
connections = lgraph.Connections;

% edit(fullfile(matlabroot,'examples','nnet','main','freezeWeights.m'))
% edit(fullfile(matlabroot,'examples','nnet','main','createLgraphUsingConnections.m'))
 
layers(1:110) = freezeWeights(layers(1:110));
lgraph = createLgraphUsingConnections(layers,connections);

%% 训练网络
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',3, ...
    'ValidationPatience',Inf, ...
    'Verbose',false ,...
    'Plots','training-progress');

net = trainNetwork(augimdsTrain,lgraph,options);


%% 对验证图像进行分类
[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels);

猜你喜欢

转载自blog.csdn.net/weixin_40249164/article/details/82629928
今日推荐