BP神经网络做分类+隐含层节点确定+红酒数据为例

网上用BP神经网络做预测的代码有很多,但是做分类的很少,(虽然都是一个道理),但是预测的代码下载下来还得动手修改,对于想直接复制粘贴的友友们很不友好。想用分类代码的直接来我这里复制粘贴即可,跑不通的欢迎来dao我。

废话不多说,上干货了!

老规矩,先上结果图!

以上两个图片道理相同,只不过展现形式不一致而已。

红酒数据:178×13列,再加一列标签。选取百分之70作为训练集,百分之30作为测试集。训练结果如上图,准确率达到了98%以上,当然这只是作者在运行了n多次后,选出的比较好的一次结果,不见得每次都这么高哈!大家运行的结果如果没这么高,请不要刀我。。。

好啦,上代码!

该段代码包含了,数据归一化处理,数据维数变更,以准确率为标准,求最佳隐含层节点数,画热力图,画结果图等,大家可以自行复制粘贴,选取自己想要的运行!哪里看不懂的欢迎评论区留言!  

对了,plotConfMat.m 是一个引用的函数哦,别忘记copy这个!

close all
warning off
%% 数据读取
clc
clear
load Wine
%% 数据载入
data=Wine;
data=data(randperm(size(data,1)),:);    %此行代码用于打乱原始样本,使训练集测试集随机被抽取,有助于更新预测结果。
input=data(:,2:end);
output1 =data(:,1);
%把输出从1维变成3维
for i=1:size(data,1)
    switch output1(i)
        case 1
            output(i,:)=[1 0 0];
        case 2
            output(i,:)=[0 1 0];
        case 3
            output(i,:)=[0 0 1];
     end
end

%% 选取训练数据和测试数据
m=fix(size(data,1)*0.7);    %训练的样本数目
input_train=input(1:m,:)';
output_train=output(1:m,:)';
input_test=input(m+1:end,:)';
output_test=output(m+1:end,:)';
%% 数据归一化
[inputn,inputps]=mapminmax(input_train,0,1);
% [outputn,outputps]=mapminmax(output_train);
inputn_test=mapminmax('apply',input_test,inputps);

%% 获取输入层节点、输出层节点个数
inputnum=size(input,2);
outputnum=size(output,2);
disp('/')
disp('神经网络结构...')
disp(['输入层的节点数为:',num2str(inputnum)])
disp(['输出层的节点数为:',num2str(outputnum)])
disp(' ')
disp('隐含层节点的确定过程...')

%确定隐含层节点个数
%采用经验公式hiddennum=sqrt(m+n)+a,m为输入层节点个数,n为输出层节点个数,a一般取为1-10之间的整数
ACC=0; %初始化最小误差
for hiddennum=fix(sqrt(inputnum+outputnum))+1:fix(sqrt(inputnum+outputnum))+10
    
    %构建网络
    net=newff(inputn,output_train,hiddennum);
    % 网络参数
    net.trainParam.epochs=1000;         % 训练次数
    net.trainParam.lr=0.01;                   % 学习速率
    net.trainParam.goal=0.000001;        % 训练目标最小误差
    % 网络训练
    net=train(net,inputn,output_train);
    an0=sim(net,inputn);  %仿真结果
    predict_label=zeros(1,size(an0,2));
    for i=1:size(an0,2)
        predict_label(i)=find(an0(:,i)==max(an0(:,i)));
    end
    outputt=zeros(1,size(output_train,2));
    for i=1:size(output_train,2)
        outputt(i)=find(output_train(:,i)==max(output_train(:,i)));
    end
%     fprintf('test is over and plot begining……\n');
    accuracy=sum(outputt==predict_label)/length(predict_label);   %计算预测的确率
    disp(['当隐含层节点为:',num2str(hiddennum),',准确率为:',num2str(accuracy*100),'%'])
    %更新最佳的隐含层节点
    if accuracy>ACC
        ACC=accuracy;
        hiddennum_best=hiddennum;
    end
end
disp(['最佳的隐含层节点数为:',num2str(hiddennum_best),',相应的准确率为:',num2str(ACC*100),'%'])

%% 构建最佳隐含层节点的BP神经网络
disp(' ')
disp('标准的BP神经网络:')
net0=newff(inputn,output_train,hiddennum_best,{'tansig','purelin'},'trainlm');% 建立模型
%网络参数配置
net0.trainParam.epochs=1000;            % 训练次数,这里设置为1000次
net0.trainParam.lr=0.01;                % 学习速率,这里设置为0.01
net0.trainParam.goal=0.00001;           % 训练目标最小误差,这里设置为0.0001
net0.trainParam.show=25;                % 显示频率,这里设置为每训练25次显示一次
net0.trainParam.mc=0.01;                % 动量因子
net0.trainParam.min_grad=1e-6;          % 最小性能梯度
net0.trainParam.max_fail=6;             % 最高失败次数

%开始训练
net0=train(net0,inputn,output_train);
%预测
an0=sim(net0,inputn_test); %用训练好的模型进行仿真
predict_label=zeros(1,size(an0,2));
for i=1:size(an0,2)
    predict_label(i)=find(an0(:,i)==max(an0(:,i)));
end
outputt=zeros(1,size(output_test,2));
for i=1:size(output_test,2)
    outputt(i)=find(output_test(:,i)==max(output_test(:,i)));
end
fprintf('test is over and plot begining……\n');
accuracy=sum(outputt==predict_label)/length(predict_label);   %计算预测的确率
disp(['准确率:',num2str(accuracy*100),'%'])

 %% 作图
figure
stem(1:length(predict_label),predict_label,'b^')
hold on
stem(1:length(predict_label),outputt,'r*')
xlim([0 45])
legend('预测类别','真实类别','NorthWest')
title({'BP神经网络的预测效果',['测试集正确率 = ',num2str(accuracy*100),' %']})
xlabel('预测样本编号')
ylabel('分类结果')
set(gca,'fontsize',12)
%输出准确率
disp('---------------------------测试准确率-------------------------')
 disp(['准确率:',num2str(accuracy*100),'%'])
%% 画方框图
confMat = confusionmat(outputt,predict_label);  %output_test是真实值标签
figure;
set(gcf,'unit','centimeters','position',[15 5 13 9])
plotConfMat(confMat.');  
xlabel('Predicted label')
ylabel('Real label')
hold off

plotConfMatm.m函数如下:

function plotConfMat(varargin)
%PLOTCONFMAT plots the confusion matrix with colorscale, absolute numbers
%   and precision normalized percentages
%
%   usage: 
%   PLOTCONFMAT(confmat) plots the confmat with integers 1 to n as class labels
%   PLOTCONFMAT(confmat, labels) plots the confmat with the specified labels
%
%   Vahe Tshitoyan
%   20/08/2017
%
%   Arguments
%   confmat:            a square confusion matrix
%   labels (optional):  vector of class labels

% number of arguments
switch (nargin)
    case 0
       confmat = 1;
       labels = {'1'};
    case 1
       confmat = varargin{1};
       labels = 1:size(confmat, 1);
    otherwise
       confmat = varargin{1};
       labels = varargin{2};
end

confmat(isnan(confmat))=0; % in case there are NaN elements
numlabels = size(confmat, 1); % number of labels

% calculate the percentage accuracies
confpercent = 100*confmat./repmat(sum(confmat, 1),numlabels,1);

% plotting the colors
imagesc(confpercent);
title(sprintf('Accuracy: %.2f%%', 100*trace(confmat)/sum(confmat(:))));
ylabel('Output Class (Truth)'); xlabel('Target Class (Predicted)');

% set the colormap
colormap(flipud(gray));

% Create strings from the matrix values and remove spaces
textStrings = num2str([confpercent(:), confmat(:)], '%.1f%%\n%d\n');
textStrings = strtrim(cellstr(textStrings));

% Create x and y coordinates for the strings and plot them
[x,y] = meshgrid(1:numlabels);
hStrings = text(x(:),y(:),textStrings(:), ...
    'HorizontalAlignment','center','fontsize',13);

% Get the middle value of the color range
midValue = mean(get(gca,'CLim'));

% Choose white or black for the text color of the strings so
% they can be easily seen over the background color
textColors = repmat(confpercent(:) > midValue,1,3);
set(hStrings,{'Color'},num2cell(textColors,2));

% Setting the axis labels
set(gca,'XTick',1:numlabels,...
    'XTickLabel',labels,...
    'YTick',1:numlabels,...
    'YTickLabel',labels,...
    'TickLength',[0 0]);
a1 = get(gca,'XTickLabel');
set(gca,'XTickLabel',a1,'fontsize',13)

end

最后祝愿各位小伙伴,所有代码永不出错! 

觉着不错的给博主留个小赞吧!您的一个小赞就是博主更新的动力!谢谢!

猜你喜欢

转载自blog.csdn.net/woaipythonmeme/article/details/128819916
今日推荐