分类预测 | MATLAB实现基于Attention-LSTM的数据分类预测多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)

分类预测 | MATLAB实现基于Attention-LSTM的数据分类预测多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)

效果一览

1
2

1
2

基本介绍

分类预测 | MATLAB实现基于Attention-LSTM的数据分类预测多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)

程序设计

  • 完整程序和数据私信博主回复:Attention-LSTM的数据分类预测多特征分类预测
% 需要学习的参数
lstmweight = params.lstm.weights;
lstmrecurrentWeights = params.lstm.recurrentWeights;
lstmbias = params.lstm.bias;
% 不同批次间传递的参数(这里假设每一轮epoch中,不同Batch间的state是传递的,但不学习;
h0 = state.lstm.h0;
c0 = state.lstm.c0;
[Lstm_Y,h0,c0] = lstm(Train_X,h0,c0,lstmweight,lstmrecurrentWeights,lstmbias);

Htt = dlarray(Lstm_Y(:,:,1:end-1),'SBSC');    %转变成CNN输入格式,’SS为

%% Attention
Attentionweight  = params.attention.weight; % 计算得分权重
Att = dlarray(squeeze(sum(CnnHttAtt .* dlarray(Attentionweight,'SC'),2)),'SBC'); %'C'维度为cnn卷积后的每一行
Ht = Lstm_Y(:,:,end);       % 参考向量

HtAfter = dlarray(repmat(Ht,[1,1,50]),'SBC');
f = squeeze(sum(HtAfter.*Att,1));
socre = sigmoid(f);                   % 计算得分'CB'
socre = dlarray(repmat(socre,[1,1,6]),'CBS'); 

% 组成Vt
CnnAfterRow = dlarray(squeeze(CnnHtt),'CSB');    % 满足与socre维度一致
Vt = sum(CnnAfterRow .*socre,2);
Vt = squeeze(Vt);



%% Attention输出
weight1 = params.attenout.weight1;
bias1 = params.attenout.bias1;
weight2 = params.attenout.weight2;
bias2 = params.attenout.bias2;
Hthat = fullyconnect(Vt,weight1,bias1) + fullyconnect(Ht,weight2,bias2);

%% 全连接层前置层(降维)
LastWeight = params.fullyconnect.weight1;
LastBias = params.fullyconnect.bias1 ;
FullyconnectInput = fullyconnect(Hthat,LastWeight,LastBias);
FullyconnectInput = relu(FullyconnectInput);

参考资料

[1] https://blog.csdn.net/kjm13182345320/article/details/128163536?spm=1001.2014.3001.5502
[2] https://blog.csdn.net/kjm13182345320/article/details/128151206?spm=1001.2014.3001.5502

猜你喜欢

转载自blog.csdn.net/kjm13182345320/article/details/131756054
今日推荐