MATLAB R2018a:ResNet-50 迁移学习

该函数用来获取数据集(data_resize 是我的图像文件夹名,其中每个类别对应一个子文件夹,子文件夹名即为标签;训练时改成自己的文件夹即可)

function [dataTrain,dataTest] = merchData()
%MERCHDATA Build the training and test image datastores.
%   [dataTrain,dataTest] = merchData() creates an imageDatastore over the
%   folder 'data_resize' (subfolders included, labels taken from the
%   folder names), hands 70% of the files of each label to the first
%   split and the remainder to dataTest, and returns the first split
%   shuffled as dataTrain.
%
%   The MathWorks sample set can be used instead by first running
%   unzip(fullfile(matlabroot,'examples','nnet','MerchData.zip')) and
%   pointing imageFolder at the extracted folder.

    imageFolder   = 'data_resize';
    trainFraction = 0.7;

    % One label per subfolder of imageFolder.
    allImages = imageDatastore(imageFolder, ...
        'LabelSource','foldernames', ...
        'IncludeSubfolders',true);

    % Per-label split; only the training part is shuffled afterwards.
    [trainSplit,dataTest] = splitEachLabel(allImages,trainFraction);
    dataTrain = shuffle(trainSplit);

end

下面的代码为主脚本(与上面的 merchData 函数分开保存为单独的 .m 文件):

% Transfer learning with a pretrained ResNet-50.
%
% The script loads the pretrained network, swaps its final three layers
% (fully connected / softmax / classification output) for a new head sized
% to the number of classes in the local dataset, trains the result and
% reports the accuracy on the held-out test split.
%
% The graph is edited through layerGraph(net), so every pretrained layer
% and every skip connection of the original network is kept as-is; only
% the classification head is new and randomly initialised.

% Pretrained ResNet-50.
net = resnet50;

% Load the data and count the classes (value is displayed on purpose).
% NOTE(review): no resizing is done in this script, so the images in the
% datastore must already match the network's input size - confirm.
[merchImagesTrain,merchImagesTest] = merchData();
numClasses = numel(categories(merchImagesTrain.Labels))

% Full layer graph of the pretrained network, connections included.
lgraph = layerGraph(net);
figure;plot(lgraph)

% The last three layers form the original classification head; the layer
% just before them is where the new head is attached. Names are read from
% the network rather than hard-coded.
oldHead      = net.Layers(end-2:end);
oldHeadNames = {oldHead.Name};
lastKeptName = net.Layers(end-3).Name;
lgraph = removeLayers(lgraph,oldHeadNames);

% New classification head for numClasses outputs.
newHead = [
    fullyConnectedLayer(numClasses,'Name','fc3','WeightLearnRateFactor',1,'BiasLearnRateFactor',1)
    softmaxLayer('Name','fc3_softmax')
    classificationLayer('Name','ClassificationLayer_fc3')
    ];

% addLayers connects the layers of newHead to each other in sequence; the
% single remaining connection is from the retained backbone into 'fc3'.
lgraph_1 = addLayers(lgraph,newHead);
lgraph_1 = connectLayers(lgraph_1,lastKeptName,'fc3');
figure;plot(lgraph_1)

% Training configuration.
options = trainingOptions('sgdm',...
    'MiniBatchSize',5,...
    'MaxEpochs',10,...
    'InitialLearnRate',0.0001);

% Train, then classify the test split.
netTransfer = trainNetwork(merchImagesTrain, lgraph_1,options);
predictedLabels = classify(netTransfer,merchImagesTest);
testLabels = merchImagesTest.Labels;

% Fraction of test images whose predicted label equals the true label
% (displayed on purpose).
accuracy = sum(predictedLabels==testLabels)/numel(predictedLabels)







猜你喜欢

转载自blog.csdn.net/qq_31442743/article/details/80175391