该函数用于读取数据集并划分训练集/测试集。data_resize 是我的数据文件夹名（训练时改成你自己的数据文件夹即可），其下每个子文件夹对应一个类别，子文件夹名即为类别标签；每个类别按 7:3 划分为训练集和测试集，训练集会被打乱顺序。
function [dataTrain,dataTest] = merchData(dataFolder, trainFraction)
%MERCHDATA Load a folder-labelled image dataset and split it into train/test.
%   [dataTrain,dataTest] = merchData() reads every image under the folder
%   'data_resize' (including subfolders). It labels each image with the
%   name of the folder that contains it. Each label is then split 70%/30%
%   into a training and a test imageDatastore. The training datastore is
%   shuffled; the test datastore keeps its original order.
%
%   [dataTrain,dataTest] = merchData(dataFolder) reads from dataFolder
%   instead of 'data_resize'.
%
%   [dataTrain,dataTest] = merchData(dataFolder, trainFraction) puts the
%   proportion trainFraction of each label's files into dataTrain and the
%   rest into dataTest.
%
%   Passing [] for either argument selects its default.
%
%   Note: splitEachLabel is called without the 'randomized' option, so
%   (per its documentation) files are assigned to the two sets in
%   datastore order. The split itself is therefore deterministic; only
%   the order of the training set is randomized by shuffle.

% The MathWorks MerchData example dataset can be unpacked with:
%   unzip(fullfile(matlabroot,'examples','nnet','MerchData.zip'));

if nargin < 1 || isempty(dataFolder)
    dataFolder = 'data_resize';    % original hard-coded dataset folder
end
if nargin < 2 || isempty(trainFraction)
    trainFraction = 0.7;           % original hard-coded train proportion
end

% Folder names become the categorical labels.
data = imageDatastore(dataFolder, ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

[dataTrain, dataTest] = splitEachLabel(data, trainFraction);
dataTrain = shuffle(dataTrain);
end
下面的代码为主程序（脚本）：在预训练的 ResNet-50 上替换分类层进行迁移学习训练，并在测试集上计算分类准确率。
% Fine-tune a pretrained ResNet-50 on the dataset returned by merchData.
%
% The pretrained network is converted to a layer graph. Its original
% classifier head (the last three layers of net.Layers) is removed, and a
% new numClasses-way head (fc3 -> fc3_softmax -> ClassificationLayer_fc3)
% is attached to the layer that fed the old head.
%
% Editing layerGraph(net) directly keeps every residual connection and
% every pretrained weight. That includes the projection-shortcut layers
% res2a_branch1/bn2a_branch1 ... res5a_branch1/bn5a_branch1. Building a
% graph from the flat net.Layers array would instead chain the layers
% linearly and lose the shortcuts. Replacing those shortcut layers with new
% convolution2dLayer/batchNormalizationLayer objects would discard their
% trained weights and normalization statistics.
%
% NOTE(review): trainNetwork does not resize images read from an
% imageDatastore. This assumes the images in the data folder already match
% net.Layers(1).InputSize (the folder name 'data_resize' suggests so -
% confirm).

% Pretrained ResNet-50 (requires the ResNet-50 support package).
net = resnet50;

% Shuffled training set and held-out test set.
[merchImagesTrain, merchImagesTest] = merchData();
numClasses = numel(categories(merchImagesTrain.Labels))

% Take the layer names from the network itself instead of hard-coding them.
% The last three layers are the old head, and the layer before them
% provides the features. This matches the original net.Layers(1:end-3) cut.
allLayers        = net.Layers;
headLayerNames   = {allLayers(end-2:end).Name};
featureLayerName = allLayers(end-3).Name;

lgraph = layerGraph(net);
lgraph = removeLayers(lgraph, headLayerNames);

newHead = [
    fullyConnectedLayer(numClasses, 'Name', 'fc3', ...
        'WeightLearnRateFactor', 1, 'BiasLearnRateFactor', 1)
    softmaxLayer('Name', 'fc3_softmax')
    classificationLayer('Name', 'ClassificationLayer_fc3')
    ];
% addLayers connects the layers of an array sequentially.
lgraph = addLayers(lgraph, newHead);
lgraph = connectLayers(lgraph, featureLayerName, 'fc3');
figure; plot(lgraph)

% Train.
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 5, ...
    'MaxEpochs', 10, ...
    'InitialLearnRate', 0.0001);
netTransfer = trainNetwork(merchImagesTrain, lgraph, options);

% Evaluate on the held-out split: fraction of correctly classified images.
predictedLabels = classify(netTransfer, merchImagesTest);
testLabels = merchImagesTest.Labels;
accuracy = mean(predictedLabels == testLabels)