% 回声状态网络 (Echo State Network, ESN) -- MATLAB example

% ---- Load the data and split it into training / test segments ----
% Columns 1-3 are used as inputs (x) and column 4 as the target (y).
% The first n_train rows form the training set; all remaining rows form
% the test set.
data_file = 'D:\matlab2016a\data.csv';
n_train = 145;

dataset = csvread(data_file);
if size(dataset,2) < 4
    error('Expected at least 4 columns in %s, found %d.', data_file, size(dataset,2));
end
% Fail early with a clear message instead of an index error (fewer than
% n_train rows) or an empty test set (exactly n_train rows).
if size(dataset,1) <= n_train
    error('Need more than %d rows in %s to leave a test set (found %d).', ...
        n_train, data_file, size(dataset,1));
end

x = dataset(:,1:3);
y = dataset(:,4);

x_train = x(1:n_train,:);
y_train = y(1:n_train);
x_test = x(n_train+1:end,:);
y_test = y(n_train+1:end);

% ---- Echo state network: reservoir setup, training and evaluation ----
% Expects x_train, y_train, x_test, y_test in the workspace (one row per
% time step; y_train / y_test are column vectors). Produces w_in, w_state,
% w_back, w_out, the reservoir states s, the row vectors predict_train /
% predict_test, and train_RMSE / test_RMSE.

rng(0);                 % fixed seed so the reported RMSE is reproducible

[samples,inputnum] = size(x_train);
hiddennum = 10;         % number of reservoir neurons
outputnum = 1;
spectral_radius = 0.9;  % target spectral radius of w_state; keeping it < 1
                        % is the usual heuristic for the echo state property

w_in = rand(hiddennum,inputnum);
% An all-positive rand(n,n) has a dominant eigenvalue near n/2 (about 5 for
% n = 10), which pushes every neuron into tanh saturation. Centre the
% entries, then rescale to the requested spectral radius.
w_state = rand(hiddennum,hiddennum) - 0.5;
w_state = w_state * (spectral_radius / max(abs(eig(w_state))));
w_back = rand(hiddennum,outputnum);

sizepop = size(x_test,1);
% One row per time step: rows 1..samples hold training states and the
% remaining rows hold test states. Preallocating also discards any stale s
% left in the workspace by an earlier run.
s = zeros(samples+sizepop,hiddennum);

% ---- Training: drive the reservoir with teacher forcing ----
% The feedback term uses the PREVIOUS target y(i-1). Feeding y(i) would leak
% the value being predicted into the state that is used to predict it.
state = zeros(hiddennum,1);
y_prev = 0;             % no output is known before the first step
for i = 1:samples
    state = tanh(w_in*x_train(i,:)' + w_state*state + w_back*y_prev);
    s(i,:) = state';
    y_prev = y_train(i);   % feedback for step i+1
end

% Linear readout fitted by least squares. Backslash solves the tall system
% directly instead of explicitly inverting the possibly ill-conditioned s'*s.
w_out = (s(1:samples,:) \ y_train)';

predict_train = w_out * s(1:samples,:)';   % 1-by-samples row vector
% RMSE = sqrt(mean of squared residuals)
train_RMSE = sqrt(mean((predict_train - y_train').^2));
disp(['Train RMSE = ',num2str(train_RMSE)])

% ---- Test: one-step-ahead prediction ----
% The reservoir continues from the last training state. After step j is
% predicted, the true y_test(j) is fed back for step j+1. For a free-running
% forecast, feed back predict_test(j) instead.
predict_test = zeros(1,sizepop);
y_prev = y_train(end);  % last target observed before the test segment
for j = 1:sizepop
    state = tanh(w_in*x_test(j,:)' + w_state*state + w_back*y_prev);
    s(samples+j,:) = state';
    predict_test(j) = w_out*state;
    y_prev = y_test(j);
end
test_RMSE = sqrt(mean((predict_test - y_test').^2));
disp(['Test RMSE = ',num2str(test_RMSE)])

% ---- Plot targets (solid red) against predictions (dashed blue) ----
% Open a new figure and release hold after each panel. Otherwise subplot
% reuses the previous run's axes with hold still on, and re-running the
% script stacks new curves on top of old ones.
figure;
subplot(1,2,1)
plot(y_train,'-r')
hold on
plot(predict_train,'--b')
hold off
title('Training set')
legend('target','prediction')

subplot(1,2,2)
plot(y_test,'-r')
hold on
plot(predict_test,'--b')
hold off
title('Test set')
legend('target','prediction')

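% Output reported in the original post for the code as originally written:
%   Train RMSE = 32.9144
%   Test RMSE = 19.0914
% These figures are not reproducible, because that code draws its weights
% with unseeded rand. They are also not true RMSE values: that code's RMSE
% line squares the SUM of the residuals (sum(e)^2) instead of summing the
% squared residuals.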

% Source: reposted from (转载自) blog.csdn.net/qq_42394743/article/details/83086257
% Tag: esn