长短期记忆网络(LSTM)的 MATLAB 实现

% Train a single-layer LSTM on a univariate time series, one sample at a
% time, calling WeightFunc after every sample to update the weights.
% Leaves the trained weights, biases, h_state, cell_state, train, test,
% num, inputnum and hiddennum in the workspace.
load dataset                        % expects dataset.mat to provide the variable `dataset`
stop = 0;                           % trailing observations to drop: 0 predicts the last value, 1 the second to last, ...
series = dataset(1:end-stop,1);     % time series taken from the first column of `dataset`
timespan = 5;                       % number of historical values used as network input

% Build the sliding-window samples: every column of `train` holds `timespan`
% inputs followed by the value to predict. The last window is held out as `test`.
numdata = size(series,1);
numsample = numdata - timespan - 1;
assert(numsample >= 1, 'LSTM:notEnoughData', ...
    'Need at least %d observations to build one training sample, got %d.', ...
    timespan + 2, numdata);
% Allocate explicitly so a stale `train` left in the workspace (or loaded from
% dataset.mat) cannot leak extra rows/columns into the sample matrix.
train = zeros(timespan+1, numsample);
for i = 1:numsample
    train(:,i) = series(i:i+timespan,:);
end
test = series(numdata-timespan:numdata,:)';

% Network layer sizes
len = size(train,1)-1;
num = size(train,2);
inputnum = len;
hiddennum = 10;
outputnum = 1;

% Gate biases (WeightFunc does not update them, so they stay at these values)
bias_input_gate = rand(1,hiddennum);
bias_forget_gate = rand(1,hiddennum);
bias_output_gate = rand(1,hiddennum);

% Weight initialisation; `ab` scales the random start values down
ab = 20;
W_input_x = rand(inputnum,hiddennum)/ab;
W_input_h = rand(outputnum,hiddennum)/ab;

W_inputgate_x = rand(inputnum,hiddennum)/ab;
W_inputgate_c = rand(hiddennum,hiddennum)/ab;

W_forgetgate_x = rand(inputnum,hiddennum)/ab;
W_forgetgate_c = rand(hiddennum,hiddennum)/ab;

W_outputgate_x = rand(inputnum,hiddennum)/ab;
W_outputgate_c = rand(hiddennum,hiddennum)/ab;

W_preh_h = rand(hiddennum,outputnum);

% Network state initialisation
cost_gate = 1e-6;                   % squared-error threshold that stops training
h_state = rand(outputnum,num);
cell_state = rand(hiddennum,num);

sigmoid = @(z) 1./(1+exp(-z));      % logistic function, applied element-wise

lr = 0.01;
maxgen = 2000;
error_cost = zeros(1,maxgen);       % squared error of the last processed sample, per epoch
converged = false;
for i = 1:maxgen
    for j = 1:num
        x = train(1:inputnum,j)';   % 1 x inputnum input row of sample j
        if j == 1
            % First sample: no previous hidden/cell state, so the forget gate is off.
            gate = tanh(x*W_input_x);
            input_gate_input = x*W_inputgate_x + bias_input_gate;
            output_gate_input = x*W_outputgate_x + bias_output_gate;
            forget_gate_input = zeros(1,hiddennum);
            input_gate = sigmoid(input_gate_input);
            output_gate = sigmoid(output_gate_input);
            forget_gate = zeros(1,hiddennum);
            cell_state(:,j) = (input_gate.*gate)';
        else
            h_prev = h_state(:,j-1)';
            c_prev = cell_state(:,j-1)';
            gate = tanh(x*W_input_x + h_prev*W_input_h);
            input_gate_input = x*W_inputgate_x + c_prev*W_inputgate_c + bias_input_gate;
            forget_gate_input = x*W_forgetgate_x + c_prev*W_forgetgate_c + bias_forget_gate;
            output_gate_input = x*W_outputgate_x + c_prev*W_outputgate_c + bias_output_gate;
            input_gate = sigmoid(input_gate_input);
            forget_gate = sigmoid(forget_gate_input);
            output_gate = sigmoid(output_gate_input);
            cell_state(:,j) = (input_gate.*gate + c_prev.*forget_gate)';
        end
        pre_h_state = tanh(cell_state(:,j)').*output_gate;
        h_state(:,j) = (pre_h_state*W_preh_h)';

        % `err` rather than `error`, so the built-in error() is not shadowed.
        err = h_state(:,j) - train(end,j);
        error_cost(1,i) = sum(err.^2);
        if error_cost(1,i) < cost_gate
            % `break` leaves only the sample loop; the flag lets the epoch
            % loop stop as well instead of silently continuing to train.
            converged = true;
            break;
        else
            [W_input_x,W_input_h,...
             W_inputgate_x,W_inputgate_c,...
             W_forgetgate_x,W_forgetgate_c,...
             W_outputgate_x,W_outputgate_c,W_preh_h] = WeightFunc(j,lr,err,...
                                          W_input_x,W_input_h,...
                                          W_inputgate_x,W_inputgate_c,...
                                          W_forgetgate_x,W_forgetgate_c,...
                                          W_outputgate_x,W_outputgate_c,...
                                          W_preh_h,...
                                          cell_state,h_state,...
                                          input_gate,forget_gate,output_gate,gate,...
                                          train,pre_h_state,...
                                          input_gate_input,output_gate_input,forget_gate_input);
        end
    end
    if converged
        break;
    end
end
error_cost = error_cost(1:i);       % keep only the epochs that actually ran

扫描二维码关注公众号,回复: 3447274 查看本文章

% Network prediction: one forward pass over the held-out window `test`,
% using the weights, biases and states left in the workspace by training.
% The test window is the sliding window that directly follows training
% sample `num`, so it sits at position num+1 and its "previous" state is
% column `num` (m-1 below). Using m = num would read column num-1, i.e. the
% state from two steps back, and would index column 0 when num == 1.
% NOTE(review): if training stopped part-way through an epoch, column `num`
% may hold the state from an earlier epoch - confirm this is acceptable.
m = num + 1;
x_test = test(1:inputnum);          % 1 x inputnum input row; test(end) is the true value
h_prev = h_state(:,m-1)';
c_prev = cell_state(:,m-1)';
gate = tanh(x_test*W_input_x + h_prev*W_input_h);
input_gate_input = x_test*W_inputgate_x + c_prev*W_inputgate_c + bias_input_gate;
forget_gate_input = x_test*W_forgetgate_x + c_prev*W_forgetgate_c + bias_forget_gate;
output_gate_input = x_test*W_outputgate_x + c_prev*W_outputgate_c + bias_output_gate;
% Logistic activation of the three gates, element-wise.
input_gate = 1./(1+exp(-input_gate_input));
forget_gate = 1./(1+exp(-forget_gate_input));
output_gate = 1./(1+exp(-output_gate_input));
cell_state_test = (input_gate.*gate + c_prev.*forget_gate)';
pre_h_state = tanh(cell_state_test').*output_gate;
predict_test = (pre_h_state*W_preh_h)';
disp(['真实值 = ',num2str(test(end)),' 预测值 = ',num2str(predict_test)])

---------------------------------------------------------------------------------------------------------------

function [W_input_x,W_input_h,...
             W_inputgate_x,W_inputgate_c,...
             W_forgetgate_x,W_forgetgate_c,...
             W_outputgate_x,W_outputgate_c,W_preh_h] = WeightFunc(n,lr,error,...
                                          W_input_x,W_input_h,...
                                          W_inputgate_x,W_inputgate_c,...
                                          W_forgetgate_x,W_forgetgate_c,...
                                          W_outputgate_x,W_outputgate_c,...
                                          W_preh_h,...
                                          cell_state,h_state,...
                                          input_gate,forget_gate,output_gate,gate,...
                                          train,pre_h_state,...
                                          input_gate_input,output_gate_input,forget_gate_input)
%WEIGHTFUNC One gradient-descent step on the LSTM weights for sample n.
%   The loss is the squared output error, so every gradient carries the
%   factor 2*error. Only the first output unit is used (error(1,1) and
%   W_preh_h(:,1)), as in the original single-output implementation.
%
%   Inputs
%     n      - column index of the current sample in train/cell_state/h_state
%     lr     - learning rate
%     error  - network output minus target for sample n
%     W_*    - current weight matrices; the updated ones are returned
%     cell_state, h_state - per-sample cell states / outputs (one column each)
%     input_gate, forget_gate, output_gate, gate - 1 x hidden activations of
%              the forward pass for sample n (gate is the tanh candidate)
%     train  - sample matrix; rows 1..end-1 are inputs, the last row the target
%     pre_h_state - 1 x hidden output of the cell before the output layer
%     *_gate_input - 1 x hidden pre-activations of the three gates
%
%   For n == 1 there is no previous state, so only W_preh_h, W_outputgate_x,
%   W_inputgate_x and W_input_x are updated; the rest are returned unchanged.
%   All gradients are evaluated with the weights passed in, not with
%   partially updated ones.

input_num = size(train,1) - 1;             % number of input rows (timespan)
x = train(1:input_num,n);                  % input_num x 1 input column of sample n
err = error(1,1);

% Error propagated back through the output layer, 1 x hidden.
back = 2*(W_preh_h(:,1)*err)';
tanh_cell = tanh(cell_state(:,n))';
% Gradient w.r.t. the cell state: through output gate and tanh(cell).
d_cell = back.*output_gate.*(1 - tanh_cell.^2);

% Gradients w.r.t. the gate pre-activations; exp(-z).*s.^2 is the sigmoid derivative.
d_out_pre = back.*tanh_cell.*exp(-output_gate_input).*(output_gate.^2);
d_in_pre = d_cell.*gate.*exp(-input_gate_input).*(input_gate.^2);

% Pre-activation of the tanh candidate, computed ONCE from the incoming
% weights so the W_input_x and W_input_h gradients use the same forward pass.
if n ~= 1
    h_prev = h_state(:,n-1);
    c_prev = cell_state(:,n-1);
    temp = x'*W_input_x + h_prev'*W_input_h;
else
    temp = x'*W_input_x;
end
% d/dz tanh(z) = 1 - tanh(z).^2 (the square applies to tanh, not to z).
d_cand_pre = d_cell.*input_gate.*(1 - tanh(temp).^2);

% Each weight gradient is the outer product of the feeding signal (a column)
% with the pre-activation gradient (a row).
W_outputgate_x = W_outputgate_x - lr*(x*d_out_pre);
W_inputgate_x = W_inputgate_x - lr*(x*d_in_pre);

if n ~= 1
    d_forget_pre = d_cell.*c_prev'.*exp(-forget_gate_input).*(forget_gate.^2);

    W_forgetgate_x = W_forgetgate_x - lr*(x*d_forget_pre);
    W_inputgate_c = W_inputgate_c - lr*(c_prev*d_in_pre);
    W_forgetgate_c = W_forgetgate_c - lr*(c_prev*d_forget_pre);
    W_outputgate_c = W_outputgate_c - lr*(c_prev*d_out_pre);
    W_input_h = W_input_h - lr*(h_prev(1)*d_cand_pre);
end
W_input_x = W_input_x - lr*(x*d_cand_pre);

% Output-layer weights last: every gradient above used the incoming W_preh_h.
W_preh_h = W_preh_h - lr*(2*err*pre_h_state');
end

 

猜你喜欢

转载自blog.csdn.net/qq_42394743/article/details/81779233