Sparrow Search Algorithm (SSA) Optimizing CNN Hyperparameters for Regression in MATLAB

Building a CNN model involves many hyperparameters, such as the learning rate, the number of training epochs, the mini-batch size, the kernel size and number of kernels (number of feature maps) of each convolutional layer, and the number of nodes in each fully connected layer. Choosing them directly makes it hard to arrive at a satisfactory combination, so an optimization algorithm is used to tune them. Instead of relying on repeated trial and error, the optimization algorithm selects hyperparameters according to its own search strategy.

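Based on this idea, this post uses the sparrow search algorithm (SSA) to optimize nine CNN hyperparameters: the learning rate, the number of epochs, the mini-batch size, the kernel size and number of kernels of each of the two convolutional layers, and the number of neurons in each of the two fully connected layers.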

1. Principle of the sparrow search algorithm

The sparrow search algorithm (SSA) was proposed in 2020; its principle is covered in a separate article.
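The core SSA update rules can be sketched in base MATLAB as follows. This is a minimal illustration, not the author's `ssa_cnn`: the sphere objective, the population size, the 20% producer and 10% scout shares, and the safety threshold `ST = 0.8` are assumptions chosen for demonstration.

```matlab
% Minimal SSA sketch: minimizes a toy objective with the three sparrow roles.
% Assumed settings (illustrative only): sphere objective, 30 sparrows,
% 100 iterations, 20% producers, 10% scouts, safety threshold ST = 0.8.
rng(1);
fobj = @(x) sum(x.^2);                 % toy objective to minimize
dim = 5; lb = -10*ones(1,dim); ub = 10*ones(1,dim);
pop = 30; max_iter = 100;
PD = round(0.2*pop);                   % number of producers
SD = round(0.1*pop);                   % number of scouts
ST = 0.8;                              % safety threshold
clip = @(x) min(max(x, lb), ub);       % keep positions inside the bounds

X = lb + rand(pop,dim).*(ub - lb);     % random initial population
fit = zeros(pop,1);
for i = 1:pop, fit(i) = fobj(X(i,:)); end
[best_f, k] = min(fit); best_x = X(k,:);
f0 = best_f;                           % initial best, for comparison
trace = zeros(max_iter,1);             % best fitness found so far

for t = 1:max_iter
    [fit, idx] = sort(fit); X = X(idx,:);     % rank sparrows, best first
    x_worst = X(end,:); f_worst = fit(end);
    R2 = rand;                                % alarm value
    % Producers: contract toward the origin when safe, random jump otherwise
    for i = 1:PD
        if R2 < ST
            X(i,:) = X(i,:) * exp(-i/(rand*max_iter));
        else
            X(i,:) = X(i,:) + randn*ones(1,dim);
        end
        X(i,:) = clip(X(i,:)); fit(i) = fobj(X(i,:));
    end
    [~, p] = min(fit(1:PD)); x_P = X(p,:);    % best producer position
    % Scroungers: starving ones fly elsewhere, the rest move near x_P
    for i = PD+1:pop
        if i > pop/2
            X(i,:) = randn * exp((x_worst - X(i,:))/i^2);
        else
            A = 2*randi([0 1],1,dim) - 1;     % random row of +1/-1
            step = abs(X(i,:) - x_P) * (A'/(A*A'));   % uses A+ = A'(AA')^-1
            X(i,:) = x_P + step*ones(1,dim);
        end
        X(i,:) = clip(X(i,:)); fit(i) = fobj(X(i,:));
    end
    % Scouts: randomly chosen sparrows react to danger
    [g_f, g] = min(fit); x_g = X(g,:);
    for i = randperm(pop, SD)
        if fit(i) > g_f
            X(i,:) = x_g + randn(1,dim).*abs(X(i,:) - x_g);
        else
            K = 2*rand - 1;                   % uniform in [-1, 1]
            X(i,:) = X(i,:) + K*abs(X(i,:) - x_worst)/(fit(i) - f_worst + eps);
        end
        X(i,:) = clip(X(i,:)); fit(i) = fobj(X(i,:));
    end
    [f_t, k] = min(fit);                      % update the global best
    if f_t < best_f, best_f = f_t; best_x = X(k,:); end
    trace(t) = best_f;
end
fprintf('best fitness after %d iterations: %g\n', max_iter, best_f);
```

The sorted population is split into producers (the best-ranked sparrows, which lead the search), scroungers (which follow the best producer or fly off when starving), and a few randomly chosen scouts that move when they sense danger; `trace` records the best fitness found so far.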

2. Principle of optimizing a CNN with SSA

In general, tuning hyperparameters by hand means choosing a set of hyperparameters for the CNN, training and validating the network, and keeping the model that performs best on the validation set (that model can be regarded as having the optimal hyperparameters). An optimization algorithm works the same way: it keeps generating new hyperparameter combinations, builds a CNN from each combination, and trains and evaluates it. The difference is that the algorithm proposes each new combination according to its own update rules. Optimizing the CNN hyperparameters with SSA therefore means letting SSA search for the combination that gives the best performance on the validation data.
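In code, "generating a new hyperparameter combination" usually means decoding a sparrow's continuous position into values the network can use. The sketch below shows one way to do that for the nine hyperparameters; the bounds `lb`/`ub` and the helper `decode` are hypothetical and are not taken from the author's `ssa_cnn`. Every value except the learning rate must be a positive integer (epochs, batch size, kernel sizes and counts, neuron counts), so those entries are rounded.

```matlab
% Hypothetical search space for the nine hyperparameters (bounds are
% illustrative assumptions, not the values used in ssa_cnn).
%      lr     epochs batch k1size k1num k2size k2num fc1 fc2
lb = [0.001   10     8     2      2     2      2     5   5];
ub = [0.01    50     64    5      16    5      16    50  50];
is_int = [false true true true true true true true true];

% A sparrow position is a real vector inside [lb, ub]; the integer-valued
% entries are rounded before a CNN is built from it.
decode = @(x) x.*~is_int + round(x).*is_int;

rng(0);
x = lb + rand(1,9).*(ub - lb);         % one random candidate position
hp = decode(x);
fprintf(['lr=%.4g epochs=%d batch=%d conv1: size %d, %d kernels; ' ...
         'conv2: size %d, %d kernels; fc: %d, %d\n'], hp);
```

Because the integer bounds are themselves integers, rounding a value inside `[lb, ub]` never leaves the search space.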

3. Code implementation:

The data has multiple inputs and a single output: in my Excel sheet, the last column is the output and all preceding columns are inputs. You can also switch to a multi-output format, as long as you select the input and output columns correctly.
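If you switch to several outputs, only the column split needs to change. A minimal sketch, assuming the last `num_out` columns of the sheet are outputs (`num_out` is a hypothetical setting; `num_out = 1` gives the single-output layout used in this post), with a small toy matrix standing in for the Excel data:

```matlab
% Toy stand-in for data read from data.xlsx: 2 input columns, 2 output columns
data = [1 2 10 100;
        3 4 20 200;
        5 6 30 300];
num_out = 2;                              % hypothetical: number of output columns
input  = data(:, 1:end-num_out);          % feature columns
output = data(:, end-num_out+1:end);      % target column(s)
disp(size(input)); disp(size(output));
```

With this split, `mapminmax` normalizes each output row separately and `fullyConnectedLayer(size(targetD,2))` sizes the output layer to `num_out`; the evaluation metrics would then be computed per output column.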

3.1 CNN model

clc;clear;close all;rng(0)
%% Load data
data=readmatrix('data.xlsx');   % numeric matrix (xlsread is no longer recommended)
input=data(:,1:end-1);%x
output=data(:,end);%y

%% Data split
n=randperm(size(input,1));
m=round(size(input,1)*0.7); % random 70% for training, remaining 30% for testing
train_x=input(n(1:m),:);
train_y=output(n(1:m),:);
test_x=input(n(m+1:end),:);
test_y=output(n(m+1:end),:);

% Normalization (mapminmax) or standardization (mapstd)
method=@mapminmax;
% method=@mapstd;
[train_x,train_ps]=method(train_x');
test_x=method('apply',test_x',train_ps);
[train_y,output_ps]=method(train_y');
test_y=method('apply',test_y',output_ps);

feature=size(train_x,1);
num_train=size(train_x,2);
num_test=size(test_x,2);
trainD=reshape(train_x,[feature,1,1,num_train]);
testD=reshape(test_x,[feature,1,1,num_test]);
targetD = train_y';
targetD_test  = test_y';

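%% Network architecture
layers = [
    imageInputLayer([size(trainD,1) size(trainD,2) size(trainD,3)]) % input: feature x 1 x 1
    convolution2dLayer(3,4,'Stride',1,'Padding','same') % 4 filters of size 3x3 (use [3 1] for a true 3x1 kernel), stride 1, 'same' padding
    reluLayer % ReLU activation
    convolution2dLayer(3,8,'Stride',1,'Padding','same') % 8 filters of size 3x3, stride 1, 'same' padding
    reluLayer % ReLU activation
    fullyConnectedLayer(20) % fully connected layer 1, 20 neurons
    reluLayer
    fullyConnectedLayer(20) % fully connected layer 2, 20 neurons
    reluLayer
    fullyConnectedLayer(size(targetD,2)) % output layer, one neuron per target
    regressionLayer];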
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',30, ...
    'MiniBatchSize',16, ...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'shuffle','every-epoch',...
    'Verbose',false);
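train_again=1; % 1 = retrain the model, 0 = load the previously trained network
if ~exist('result','dir'), mkdir('result'); end  % save/load below expect this folder
if train_again==1
    [net,traininfo] = trainNetwork(trainD,targetD,layers,options);
    save result/cnn_net net traininfo
else
    load result/cnn_net
end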
figure;
plot(traininfo.TrainingLoss,'b')
hold on;grid on
ylabel('Loss')
xlabel('Iteration')
title('CNN')


%% Evaluation
YPred = predict(net,testD);YPred=double(YPred);

% Undo normalization
predict_value=method('reverse',YPred',output_ps);predict_value=double(predict_value);
true_value=method('reverse',targetD_test',output_ps);true_value=double(true_value);

save result/cnn_result predict_value true_value
%%
figure
plot(true_value,'-*','linewidth',3)
hold on
plot(predict_value,'-s','linewidth',3)
legend('Actual','Predicted')
title('CNN')
grid on

result(true_value,predict_value) % author's metrics helper (not shown in the post)
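The preprocessing in this script can be reproduced in base MATLAB to make the data layout explicit. The sketch below assumes `mapminmax`'s default output range of [-1, 1], ignores its special handling of constant rows, and uses toy data. It shows that each feature row is scaled with the training set's min and max (the test set reuses them, as `'apply'` does), and that `reshape` places sample k at `trainD(:,1,1,k)`, the feature x 1 x 1 x N layout fed to `imageInputLayer`.

```matlab
% Base-MATLAB re-creation of the preprocessing (illustration with toy data).
rng(0);
train_x = rand(3, 8);                   % 3 features (rows) x 8 training samples
test_x  = rand(3, 4);                   % 3 features x 4 test samples
xmin = min(train_x, [], 2);             % per-feature minimum of the TRAINING data
xmax = max(train_x, [], 2);             % per-feature maximum of the TRAINING data
scale   = @(x) 2*(x - xmin)./(xmax - xmin) - 1;   % like mapminmax(x)
unscale = @(y) (y + 1).*(xmax - xmin)/2 + xmin;   % like mapminmax('reverse',...)
tr = scale(train_x);                    % training rows now span [-1, 1]
te = scale(test_x);                     % test reuses training stats ('apply')

% Pack as feature x 1 x 1 x N: sample k becomes the "image" trainD(:,:,:,k)
trainD = reshape(tr, [size(tr,1), 1, 1, size(tr,2)]);
disp(size(trainD))
```

Test values may fall slightly outside [-1, 1]; that is expected, since only training statistics are used and no information from the test set leaks into the scaling.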

The training-loss curve and the test-set predictions are shown in the figures.

The test-set metrics are:

R-square coefficient of determination (R2): 0.79945
Mean relative error (MPE): 0.20031
Mean absolute percentage error (MAPE): 20.0306%
Mean absolute error (MAE): 0.32544
Root mean square error (RMSE): 0.47113
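The `result` helper that prints these numbers is not shown in the post. A hypothetical re-implementation using the standard definitions is sketched below; the formulas are an assumption. The reported MPE (0.20031) and MAPE (20.0306%) differ only by a factor of 100, so the sketch treats MAPE as the mean relative error expressed in percent.

```matlab
% Hypothetical stand-in for result(true_value, predict_value), with toy values
true_value    = [2.0 3.0 4.0 5.0];
predict_value = [2.2 2.7 4.4 5.0];
err  = predict_value - true_value;
R2   = 1 - sum(err.^2)/sum((true_value - mean(true_value)).^2);  % coefficient of determination
MPE  = mean(abs(err)./abs(true_value));    % mean relative error
MAPE = 100*MPE;                            % the same quantity in percent
MAE  = mean(abs(err));                     % mean absolute error
RMSE = sqrt(mean(err.^2));                 % root mean square error
fprintf('R2: %.5g\nMPE: %.5g\nMAPE: %.4f%%\nMAE: %.5g\nRMSE: %.5g\n', ...
        R2, MPE, MAPE, MAE, RMSE);
```

For these toy values the sketch gives R2 = 0.942, MPE = 0.075, MAE = 0.225, and RMSE of about 0.269.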

3.2 SSA-optimized CNN

clc;clear;close all;format compact;rng(0)

%% Load data
data=readmatrix('data.xlsx');   % numeric matrix (xlsread is no longer recommended)
input=data(:,1:end-1);%x
output=data(:,end);%y
%% Data split
n=randperm(size(input,1));
m=round(size(input,1)*0.7); % random 70% for training, remaining 30% for testing
train_x=input(n(1:m),:);
train_y=output(n(1:m),:);
test_x=input(n(m+1:end),:);
test_y=output(n(m+1:end),:);

% Normalization (mapminmax) or standardization (mapstd)
method=@mapminmax;
% method=@mapstd;
[train_x,train_ps]=method(train_x');
test_x=method('apply',test_x',train_ps);
[train_y,output_ps]=method(train_y');
test_y=method('apply',test_y',output_ps);

feature=size(train_x,1);
num_train=size(train_x,2);
num_test=size(test_x,2);
trainD=reshape(train_x,[feature,1,1,num_train]);
testD=reshape(test_x,[feature,1,1,num_test]);
targetD = train_y';
targetD_test  = test_y';


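%% Optimize the CNN hyperparameters with SSA
% Nine hyperparameters are optimized: learning rate, number of epochs, mini-batch size,
% kernel size and number of kernels of conv layers 1 and 2, and the number of
% neurons in each of the two fully connected layers
optimization=1;  % 1 = run SSA, 0 = load the saved result
if ~exist('result','dir'), mkdir('result'); end  % save/load below expect this folder
if optimization==1
    [x,trace]=ssa_cnn(trainD,targetD,testD,targetD_test);  % author's SSA routine (not shown)
    save result/ssa_result x trace
else
    load result/ssa_result
end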
%%

figure
plot(trace)
title('Fitness curve')
xlabel('Iteration')
ylabel('Fitness value')

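disp('Optimized hyperparameters')
lr=x(1)            % learning rate
iter=x(2)          % number of epochs (MaxEpochs)
minibatch=x(3)     % mini-batch size
kernel1_size=x(4)  % kernel size of conv layer 1
kernel1_num=x(5)   % number of kernels in conv layer 1
kernel2_size=x(6)  % kernel size of conv layer 2
kernel2_num=x(7)   % number of kernels in conv layer 2
fc1_num=x(8)       % neurons in fully connected layer 1
fc2_num=x(9)       % neurons in fully connected layer 2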

%% Retrain the CNN with the optimized hyperparameters and predict
rng(0)
layers = [
    imageInputLayer([size(trainD,1) size(trainD,2) size(trainD,3)])
    convolution2dLayer(kernel1_size,kernel1_num,'Stride',1,'Padding','same')
    reluLayer
    convolution2dLayer(kernel2_size,kernel2_num,'Stride',1,'Padding','same')
    reluLayer
    fullyConnectedLayer(fc1_num)
    reluLayer
    fullyConnectedLayer(fc2_num)
    reluLayer
    fullyConnectedLayer(size(targetD,2))
    regressionLayer];
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',iter, ...
    'MiniBatchSize',minibatch, ...
    'InitialLearnRate',lr, ...
    'GradientThreshold',1, ...
    'Verbose',false);

train_again=1; % 1 = retrain the model, 0 = load the previously trained network
if train_again==1
    [net,traininfo] = trainNetwork(trainD,targetD,layers,options);
    save result/ssacnn_net net traininfo
else
    load result/ssacnn_net
end

figure;
plot(traininfo.TrainingLoss,'b')
hold on;grid on
ylabel('Loss')
xlabel('Iteration')
title('SSA-CNN')


%% Evaluation
YPred = predict(net,testD);YPred=double(YPred);

% Undo normalization
predict_value=method('reverse',YPred',output_ps);predict_value=double(predict_value);
true_value=method('reverse',targetD_test',output_ps);true_value=double(true_value);

save result/ssa_cnn_result predict_value true_value
%%
figure
plot(true_value,'-*','linewidth',3)
hold on
plot(predict_value,'-s','linewidth',3)
legend('Actual','Predicted')
title('SSA-CNN')
grid on

result(true_value,predict_value) % author's metrics helper (not shown in the post)


The fitness function is the mean squared error (MSE) between the true and predicted values on the test set, and SSA minimizes it: the goal is to find the set of hyperparameters that gives the network the lowest MSE. The fitness curve therefore descends as the optimization proceeds.
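The fitness idea and the shape of the curve can be illustrated with toy numbers. The sketch assumes that each entry of `trace` is the best fitness found up to that iteration (the source of `ssa_cnn` is not shown); under that assumption, a worse candidate in a later iteration leaves the curve flat instead of pushing it up. The candidate predictions are made-up values standing in for CNN outputs.

```matlab
% MSE fitness (lower is better) and a best-so-far fitness curve, toy data only
fitness_mse = @(y_true, y_pred) mean((y_true - y_pred).^2);
y_true = [0.1 0.4 0.8];
candidates = [0.3 0.2 0.9;      % best prediction in iteration 1
              0.0 0.5 0.7;      % iteration 2
              0.4 0.1 0.5;      % iteration 3 (a worse candidate)
              0.1 0.4 0.7];     % iteration 4
n = size(candidates, 1);
fit = zeros(n, 1);
for k = 1:n
    fit(k) = fitness_mse(y_true, candidates(k,:));
end
trace = cummin(fit);            % best fitness found so far, per iteration
disp([fit trace])
```

One design note: because the fitness here is computed on the test set, the final test metrics are also the quantity being optimized; holding out part of the training data as a validation set for the fitness would keep the test evaluation independent.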

 

The test-set metrics are:

R-square coefficient of determination (R2): 0.86547
Mean relative error (MPE): 0.114
Mean absolute percentage error (MAPE): 11.3997%
Mean absolute error (MAE): 0.17238
Root mean square error (RMSE): 0.26745

All test-set metrics improve clearly: R2 rises from 0.79945 to 0.86547, MAPE falls from 20.0306% to 11.3997%, and RMSE falls from 0.47113 to 0.26745 (about 43% lower).

4. Code

The full code is available via the comment section. Related code:

1. MATLAB: sparrow search algorithm optimizing CNN hyperparameters (regression)

2. MATLAB: grey wolf optimizer optimizing CNN hyperparameters (regression)

3. MATLAB: whale optimization algorithm optimizing CNN hyperparameters (regression)

4. MATLAB: sparrow search algorithm optimizing CNN hyperparameters (classification)


Origin blog.csdn.net/qq_41043389/article/details/127539127