递归神经网络LSTM原理——结合实例MATLAB实现

最近正在看递归神经网络,看了网上很多博文,算是鱼龙混杂,并且基本都是使用Python实现,要不就是直接调用MATLAB中的函数库,对于使用MATLAB的同学甚为不方便。所以我将结合实例,使用MATLAB语言完成递归神经网络(LSTM)程序的编写。本人菜鸡一枚,如有错误,还望各路大神指正教导。文章的问题和数据与我之前写的递归神经网络BPTT文章中一致,方便大家比较两种方法的差异,文章链接:递归神经网络BPTT的MATLAB实现。另外,关于理论推导和算法步骤,等我过几天有时间再更新。
一、问题描述
问题描述
这里写图片描述
二、相关数据
相关数据
这里写图片描述
三、程序代码
LSTM_mian.m

%%% LSTM network simulation on a worked example
%%% Author: xd.wp
%%% Date: 2016.10.08  12:06
%% Notes
%  1. The data are air-conditioner power readings at four time points per day
%     over 7 days. Three consecutive days are used to predict the next day
%     (sliding window); day 7 is held out for the final check.
%  2. The LSTM has 12 input nodes, 4 output nodes and 18 hidden (cell) nodes.

clear all;
clc;
%% Load the normalized data
% train_data: one input sample per column; test_data: the matching targets
% (its 4th column is used below as the actual day-7 values).
[train_data,test_data]=LSTM_data_process();
data_num=size(train_data,2);
%% Network parameters
% Layer sizes
input_num=12;
cell_num=18;
output_num=4;
% Gate biases
bias_input_gate=rand(1,cell_num);
bias_forget_gate=rand(1,cell_num);
bias_output_gate=rand(1,cell_num);
% ab=1.2;
% bias_input_gate=ones(1,cell_num)/ab;
% bias_forget_gate=ones(1,cell_num)/ab;
% bias_output_gate=ones(1,cell_num)/ab;
% Weights: uniform random values scaled down by 1/ab
ab=20;
weight_input_x=rand(input_num,cell_num)/ab;
weight_input_h=rand(output_num,cell_num)/ab;
weight_inputgate_x=rand(input_num,cell_num)/ab;
weight_inputgate_c=rand(cell_num,cell_num)/ab;
weight_forgetgate_x=rand(input_num,cell_num)/ab;
weight_forgetgate_c=rand(cell_num,cell_num)/ab;
weight_outputgate_x=rand(input_num,cell_num)/ab;
weight_outputgate_c=rand(cell_num,cell_num)/ab;

% Hidden -> output weights
weight_preh_h=rand(cell_num,output_num);

% Stopping threshold and initial state
cost_gate=1e-6;        % stop once an epoch's summed squared error falls below this
h_state=rand(output_num,data_num);
cell_state=rand(cell_num,data_num);
%% Training
max_iter=3000;
yita=0.01;             % learning rate (constant, so set once outside the loop)
Error_Cost=zeros(1,max_iter);
for iter=1:max_iter
    for m=1:data_num
        % Forward pass
        x_t=train_data(:,m)';
        if(m==1)
            % First step: no previous hidden/cell state is used and the
            % forget gate is zeroed.
            gate=tanh(x_t*weight_input_x);
            input_gate_input=x_t*weight_inputgate_x+bias_input_gate;
            output_gate_input=x_t*weight_outputgate_x+bias_output_gate;
            input_gate=1./(1+exp(-input_gate_input));
            output_gate=1./(1+exp(-output_gate_input));
            forget_gate=zeros(1,cell_num);
            forget_gate_input=zeros(1,cell_num);
            cell_state(:,m)=(input_gate.*gate)';
        else
            gate=tanh(x_t*weight_input_x+h_state(:,m-1)'*weight_input_h);
            input_gate_input=x_t*weight_inputgate_x+cell_state(:,m-1)'*weight_inputgate_c+bias_input_gate;
            forget_gate_input=x_t*weight_forgetgate_x+cell_state(:,m-1)'*weight_forgetgate_c+bias_forget_gate;
            output_gate_input=x_t*weight_outputgate_x+cell_state(:,m-1)'*weight_outputgate_c+bias_output_gate;
            input_gate=1./(1+exp(-input_gate_input));
            forget_gate=1./(1+exp(-forget_gate_input));
            output_gate=1./(1+exp(-output_gate_input));
            cell_state(:,m)=(input_gate.*gate+cell_state(:,m-1)'.*forget_gate)';
        end
        pre_h_state=tanh(cell_state(:,m)').*output_gate;
        h_state(:,m)=(pre_h_state*weight_preh_h)';
        % Error of this sample. It is accumulated into the epoch cost so that
        % the curve and the stopping test cover every sample, not just the
        % last one (or a single lucky one).
        Error=h_state(:,m)-test_data(:,m);
        Error_Cost(1,iter)=Error_Cost(1,iter)+sum(Error.^2);
        % Backward pass: one gradient step for this sample
        [   weight_input_x,...
            weight_input_h,...
            weight_inputgate_x,...
            weight_inputgate_c,...
            weight_forgetgate_x,...
            weight_forgetgate_c,...
            weight_outputgate_x,...
            weight_outputgate_c,...
            weight_preh_h ]=LSTM_updata_weight(m,yita,Error,...
                                               weight_input_x,...
                                               weight_input_h,...
                                               weight_inputgate_x,...
                                               weight_inputgate_c,...
                                               weight_forgetgate_x,...
                                               weight_forgetgate_c,...
                                               weight_outputgate_x,...
                                               weight_outputgate_c,...
                                               weight_preh_h,...
                                               cell_state,h_state,...
                                               input_gate,forget_gate,...
                                               output_gate,gate,...
                                               train_data,pre_h_state,...
                                               input_gate_input,...
                                               output_gate_input,...
                                               forget_gate_input);
    end
    if(Error_Cost(1,iter)<cost_gate)
        break;
    end
end
Error_Cost=Error_Cost(1,1:iter);   % drop the unused preallocated tail
%% Error-cost curve (log scale)
semilogy(1:iter,Error_Cost,'*');
hold on;
title('Error-Cost曲线图');
%% Check on day 7
% Input: readings of days 4, 5 and 6, normalized like a training column
test_final=[0.4557 0.4790 0.7019 0.8211 0.4601 0.4811 0.7101 0.8298 0.4612 0.4845 0.7188 0.8312]';
test_final=test_final/sqrt(sum(test_final.^2));
test_output=test_data(:,4);        % actual (normalized) day-7 values
% Forward pass, continuing from the state of the last training sample
m=4;
gate=tanh(test_final'*weight_input_x+h_state(:,m-1)'*weight_input_h);
input_gate_input=test_final'*weight_inputgate_x+cell_state(:,m-1)'*weight_inputgate_c+bias_input_gate;
forget_gate_input=test_final'*weight_forgetgate_x+cell_state(:,m-1)'*weight_forgetgate_c+bias_forget_gate;
output_gate_input=test_final'*weight_outputgate_x+cell_state(:,m-1)'*weight_outputgate_c+bias_output_gate;
input_gate=1./(1+exp(-input_gate_input));
forget_gate=1./(1+exp(-forget_gate_input));
output_gate=1./(1+exp(-output_gate_input));
cell_state_test=(input_gate.*gate+cell_state(:,m-1)'.*forget_gate)';
pre_h_state=tanh(cell_state_test').*output_gate;
% Display prediction (first) and actual values (second)
h_state_test=(pre_h_state*weight_preh_h)'
test_output
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145

LSTM_data_process.m

function [train_data,test_data]=LSTM_data_process()
%LSTM_DATA_PROCESS  Build the normalized training inputs and targets.
%
%   [train_data,test_data] = LSTM_data_process() returns
%     train_data  12 x 3 matrix; column k holds the four readings of
%                 days k, k+1 and k+2 concatenated in that order.
%     test_data   4 x 4 matrix; column k holds the readings of day k+3,
%                 i.e. the target for training column k (column 4 is
%                 day 7, used only for the final check).
%   Every column of both outputs is scaled to unit Euclidean norm.
%
%   The windows and targets are generated from one per-day table. The
%   hand-typed version had a typo (0.4911 instead of 0.4811 for day 5) and
%   the targets for days 5 and 6 were swapped.

%% Daily readings: one row per day, four time points per day
days=[0.4413 0.4707 0.6953 0.8133;    % day 1
      0.4379 0.4677 0.6981 0.8002;    % day 2
      0.4517 0.4725 0.7006 0.8201;    % day 3
      0.4557 0.4790 0.7019 0.8211;    % day 4
      0.4601 0.4811 0.7101 0.8298;    % day 5
      0.4612 0.4845 0.7188 0.8312;    % day 6
      0.4615 0.4891 0.7201 0.8330];   % day 7
window=3;       % days per input sample
train_num=3;    % number of training samples (inputs use days 1..5)

%% Sliding windows: sample k = days k..k+window-1, target = day k+window
train_data_initial=zeros(window*size(days,2),train_num);
for k=1:train_num
    % days(k:k+window-1,:)' is points x days; column-major reshape
    % concatenates the days in chronological order.
    train_data_initial(:,k)=reshape(days(k:k+window-1,:)',[],1);
end
% Targets for the training samples plus the held-out next day
test_data_initial=days(window+1:window+train_num+1,:)';

%% Normalize every column to unit length
train_data=bsxfun(@rdivide,train_data_initial,sqrt(sum(train_data_initial.^2,1)));
test_data=bsxfun(@rdivide,test_data_initial,sqrt(sum(test_data_initial.^2,1)));
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

LSTM_updata_weight.m

function [   weight_input_x,weight_input_h,weight_inputgate_x,weight_inputgate_c,weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,weight_preh_h ]=LSTM_updata_weight(n,yita,Error,...
                                                   weight_input_x, weight_input_h, weight_inputgate_x,weight_inputgate_c,weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,weight_preh_h,...
                                                   cell_state,h_state,input_gate,forget_gate,output_gate,gate,train_data,pre_h_state,input_gate_input, output_gate_input,forget_gate_input)
%LSTM_UPDATA_WEIGHT  One gradient-descent step on the squared error of step n.
%
%   Inputs
%     n            index of the current step (column of train_data,
%                  cell_state and h_state)
%     yita         learning rate
%     Error        output_num x 1 error of step n (prediction - target);
%                  the loss being minimized is sum(Error.^2)
%     weight_*     current weights: *_x are input_num x cell_num,
%                  *_c are cell_num x cell_num, weight_input_h is
%                  output_num x cell_num, weight_preh_h is cell_num x output_num
%     cell_state, h_state    cell and output state histories (one column per step)
%     input_gate, forget_gate, output_gate   1 x cell_num gate activations of step n
%     gate         1 x cell_num candidate value, tanh of its pre-activation
%     train_data   inputs, one column per step
%     pre_h_state  1 x cell_num, tanh(cell_state(:,n))'.*output_gate
%     *_gate_input 1 x cell_num pre-activations of the sigmoid gates
%   Outputs
%     The updated weights, in the same order as the weight inputs.
%
%   Only the current step is back-propagated (no gradient through earlier
%   steps). For n==1 the recurrent weights (*_c, weight_input_h) and
%   weight_forgetgate_x are left unchanged, as in the forward pass of step 1.
%   All gradients use the weights that produced the forward pass.

x_t=train_data(:,n);                             % input of step n (column)

% dE/d(pre_h_state) with E = sum(Error.^2) and h = (pre_h_state*weight_preh_h)'.
% Summing the contributions of all outputs at once is equivalent to updating
% output by output, because no gradient term depends on the weight it updates.
dE_dpreh=2*(weight_preh_h*Error)';               % 1 x cell_num

tanh_c=tanh(cell_state(:,n))';                   % 1 x cell_num
% dE/d(cell_state) through pre_h_state = tanh(c).*output_gate
dE_dc=dE_dpreh.*output_gate.*(1-tanh_c.^2);

% Error signals at the gate pre-activations.
% sigmoid'(z) = exp(-z).*sigmoid(z).^2
d_output_gate=dE_dpreh.*tanh_c.*exp(-output_gate_input).*(output_gate.^2);
d_input_gate=dE_dc.*gate.*exp(-input_gate_input).*(input_gate.^2);
% tanh'(z) = 1 - tanh(z).^2, and gate is exactly tanh(z) from the forward pass
% (the original code used 1 - tanh(z.^2), which is not the derivative).
d_gate=dE_dc.*input_gate.*(1-gate.^2);

% Hidden -> output weights; applied last so every gradient above uses the
% old weight_preh_h.
new_weight_preh_h=weight_preh_h-yita*(pre_h_state'*(2*Error'));

% Input -> gate weights (each gradient is an outer product input x signal)
weight_outputgate_x=weight_outputgate_x-yita*(x_t*d_output_gate);
weight_inputgate_x=weight_inputgate_x-yita*(x_t*d_input_gate);
weight_input_x=weight_input_x-yita*(x_t*d_gate);

if(n~=1)
    c_prev=cell_state(:,n-1);                    % previous cell state (column)
    h_prev=h_state(:,n-1);                       % previous output (column)
    d_forget_gate=dE_dc.*c_prev'.*exp(-forget_gate_input).*(forget_gate.^2);

    weight_forgetgate_x=weight_forgetgate_x-yita*(x_t*d_forget_gate);
    % Recurrent (previous-state -> gate) weights
    weight_inputgate_c=weight_inputgate_c-yita*(c_prev*d_input_gate);
    weight_forgetgate_c=weight_forgetgate_c-yita*(c_prev*d_forget_gate);
    weight_outputgate_c=weight_outputgate_c-yita*(c_prev*d_output_gate);
    weight_input_h=weight_input_h-yita*(h_prev*d_gate);
end
weight_preh_h=new_weight_preh_h;

end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92

四、程序结果图
Error_Cost图
这里写图片描述
第七天预测值与实际值对比:第一组为预测值,第二组为实际值
这里写图片描述

猜你喜欢

转载自blog.csdn.net/sinat_22510827/article/details/79858423
今日推荐