分类预测 | MATLAB实现CNN-LSTM-Attention多输入分类预测

分类预测 | MATLAB实现CNN-LSTM-Attention多输入分类预测

分类效果

1
2
3
4
5
6

基本介绍

MATLAB实现CNN-LSTM-Attention多输入分类预测,CNN-LSTM结合注意力机制多输入分类预测。

模型描述

Matlab实现CNN-LSTM-Attention多变量分类预测
1.data为数据集,格式为excel,12个输入特征,输出四个类别;
2.MainCNN_LSTM_AttentionNC.m为主程序文件,运行即可;

注意程序和数据放在一个文件夹,运行环境为Matlab200b及以上。
4.注意力机制模块:
SEBlock(Squeeze-and-Excitation Block)是一种聚焦于通道维度而提出一种新的结构单元,为模型添加了通道注意力机制,该机制通过添加各个特征通道的重要程度的权重,针对不同的任务增强或者抑制对应的通道,以此来提取有用的特征。该模块的内部操作流程如图,总体分为三步:首先是Squeeze 压缩操作,对空间维度的特征进行压缩,保持特征通道数量不变。融合全局信息即全局池化,并将每个二维特征通道转换为实数。实数计算公式如公式所示。该实数由k个通道得到的特征之和除以空间维度的值而得,空间维数为H*W。其次是Excitation激励操作,它由两层全连接层和Sigmoid函数组成。如公式所示,s为激励操作的输出,σ为激活函数sigmoid,W2和W1分别是两个完全连接层的相应参数,δ是激活函数ReLU,对特征先降维再升维。最后是Reweight操作,对之前的输入特征进行逐通道加权,完成原始特征在各通道上的重新分配。

1
2

程序设计

  • 完整程序和数据获取方式1:同等价值程序兑换;
  • 完整程序和数据获取方式2:私信博主获取。
tempLayers = [
    reluLayer("Name", "relu_1")                                        % 激活层
    convolution2dLayer([3, 1], 64, "Name", "conv_2")                   % 卷积层 卷积核[3, 1] 步长[1, 1] 通道数 64
    reluLayer("Name", "relu_2")];                                      % 激活层
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中

tempLayers = [
    globalAveragePooling2dLayer("Name", "gapool")                      % 全局平均池化层
    fullyConnectedLayer(16, "Name", "fc_2")                            % SE注意力机制,通道数的1 / 4
    reluLayer("Name", "relu_3")                                        % 激活层
    fullyConnectedLayer(64, "Name", "fc_3")                            % SE注意力机制,数目和通道数相同
    sigmoidLayer("Name", "sigmoid")];                                  % 激活层
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中

tempLayers = multiplicationLayer(2, "Name", "multiplication");         % 点乘的注意力
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中
%%  混淆矩阵

if flag_conusion == 1

    figure
    cm = confusionchart(T_train, T_sim1);
    cm.Title = 'Confusion Matrix for Train Data';
    cm.ColumnSummary = 'column-normalized';
    cm.RowSummary = 'row-normalized';
    
    figure
    cm = confusionchart(T_test, T_sim2);
    cm.Title = 'Confusion Matrix for Test Data';
    cm.ColumnSummary = 'column-normalized';
    cm.RowSummary = 'row-normalized';
end

参考资料

[1] https://blog.csdn.net/kjm13182345320/article/details/129943065?spm=1001.2014.3001.5501
[2] https://blog.csdn.net/kjm13182345320/article/details/129919734?spm=1001.2014.3001.5501

猜你喜欢

转载自blog.csdn.net/kjm13182345320/article/details/129973597