Keras LSTM Time-Series Prediction

Data records in international-airline-passengers.csv:

time,passengers
"1949-01",112
"1949-02",118
"1949-03",132
"1949-04",129
"1949-05",121
"1949-06",135
"1949-07",148
"1949-08",148
"1949-09",136
"1949-10",119
"1949-11",104
"1949-12",118
"1950-01",115
"1950-02",126
"1950-03",141
"1950-04",135
"1950-05",125
"1950-06",149
"1950-07",170
"1950-08",170
"1950-09",158
"1950-10",133
"1950-11",114
"1950-12",140
"1951-01",145
"1951-02",150
"1951-03",178
"1951-04",163
"1951-05",172
"1951-06",178
"1951-07",199
"1951-08",199
"1951-09",184
"1951-10",162
"1951-11",146
"1951-12",166
"1952-01",171
"1952-02",180
"1952-03",193
"1952-04",181
"1952-05",183
"1952-06",218
"1952-07",230
"1952-08",242
"1952-09",209
"1952-10",191
"1952-11",172
"1952-12",194
"1953-01",196
"1953-02",196
"1953-03",236
"1953-04",235
"1953-05",229
"1953-06",243
"1953-07",264
"1953-08",272
"1953-09",237
"1953-10",211
"1953-11",180
"1953-12",201
"1954-01",204
"1954-02",188
"1954-03",235
"1954-04",227
"1954-05",234
"1954-06",264
"1954-07",302
"1954-08",293
"1954-09",259
"1954-10",229
"1954-11",203
"1954-12",229
"1955-01",242
"1955-02",233
"1955-03",267
"1955-04",269
"1955-05",270
"1955-06",315
"1955-07",364
"1955-08",347
"1955-09",312
"1955-10",274
"1955-11",237
"1955-12",278
"1956-01",284
"1956-02",277
"1956-03",317
"1956-04",313
"1956-05",318
"1956-06",374
"1956-07",413
"1956-08",405
"1956-09",355
"1956-10",306
"1956-11",271
"1956-12",306
"1957-01",315
"1957-02",301
"1957-03",356
"1957-04",348
"1957-05",355
"1957-06",422
"1957-07",465
"1957-08",467
"1957-09",404
"1957-10",347
"1957-11",305
"1957-12",336
"1958-01",340
"1958-02",318
"1958-03",362
"1958-04",348
"1958-05",363
"1958-06",435
"1958-07",491
"1958-08",505
"1958-09",404
"1958-10",359
"1958-11",310
"1958-12",337
"1959-01",360
"1959-02",342
"1959-03",406
"1959-04",396
"1959-05",420
"1959-06",472
"1959-07",548
"1959-08",559
"1959-09",463
"1959-10",407
"1959-11",362
"1959-12",405
"1960-01",417
"1960-02",391
"1960-03",419
"1960-04",461
"1960-05",472
"1960-06",535
"1960-07",622
"1960-08",606
"1960-09",508
"1960-10",461
"1960-11",390
"1960-12",432

The Keras LSTM time-series script, lstm_airline_predict.py:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import LSTM, Dense, Activation


def load_data(file_name, sequence_length=10, split=0.8):
    df = pd.read_csv(file_name, sep=',', usecols=[1])
    data_all = np.array(df).astype(float)
    # Scale the whole series into [0, 1]; the scaler is returned so the
    # predictions can be mapped back to passenger counts later.
    scaler = MinMaxScaler()
    data_all = scaler.fit_transform(data_all)
    # Slide a window of sequence_length + 1 values over the series;
    # each window becomes one sample.
    data = []
    for i in range(len(data_all) - sequence_length - 1):
        data.append(data_all[i: i + sequence_length + 1])
    reshaped_data = np.array(data).astype('float64')
    # Note: shuffling the windows before the train/test split means the
    # test set is not a chronological hold-out.
    np.random.shuffle(reshaped_data)
    # The first sequence_length values of each window form the input x and
    # the last value is the target y; both were already scaled above.
    x = reshaped_data[:, :-1]
    y = reshaped_data[:, -1]
    split_boundary = int(reshaped_data.shape[0] * split)
    train_x = x[: split_boundary]
    test_x = x[split_boundary:]

    train_y = y[: split_boundary]
    test_y = y[split_boundary:]

    return train_x, train_y, test_x, test_y, scaler


def build_model():
    # Input shape: train_x is (n_samples, time_steps, input_dim) with input_dim = 1.
    model = Sequential()
    # Keras 2 API: units are passed positionally and input_shape is declared on
    # the first layer; the old input_dim/output_dim keyword arguments are deprecated.
    model.add(LSTM(50, input_shape=(None, 1), return_sequences=True))
    model.add(LSTM(100, return_sequences=False))
    model.add(Dense(1))
    model.add(Activation('linear'))

    model.compile(loss='mse', optimizer='rmsprop')
    return model


def train_model(train_x, train_y, test_x, test_y):
    model = build_model()

    predict = None  # defined up front so the except branch cannot hit an undefined name
    try:
        model.fit(train_x, train_y, batch_size=512, epochs=30, validation_split=0.1)
        predict = model.predict(test_x)
        predict = np.reshape(predict, (predict.size, ))
    except KeyboardInterrupt:
        # If training is interrupted, report whatever has been computed so far.
        print(predict)
        print(test_y)
    print(predict)
    print(test_y)
    try:
        fig = plt.figure(1)
        plt.plot(predict, 'r:')
        plt.plot(test_y, 'g-')
        plt.legend(['predict', 'true'])
    except Exception as e:
        print(e)
    return predict, test_y


if __name__ == '__main__':
    train_x, train_y, test_x, test_y, scaler = load_data('international-airline-passengers.csv')
    # Reshape inputs to (n_samples, time_steps, input_dim) as the LSTM expects.
    train_x = np.reshape(train_x, (train_x.shape[0], train_x.shape[1], 1))
    test_x = np.reshape(test_x, (test_x.shape[0], test_x.shape[1], 1))
    predict_y, test_y = train_model(train_x, train_y, test_x, test_y)
    # Map the scaled predictions and targets back to passenger counts.
    predict_y = scaler.inverse_transform(predict_y.reshape(-1, 1))
    test_y = scaler.inverse_transform(test_y)
    fig2 = plt.figure(2)
    plt.plot(predict_y, 'g:')
    plt.plot(test_y, 'r-')
    plt.show()

Run output (the training log, followed by the model's scaled predictions and the true scaled test targets; because load_data shuffles the windows before splitting, the test samples are not in chronological order):

Epoch 1/30
95/95 [==============================] - 5s 53ms/step - loss: 0.1793 - val_loss: 0.1028
Epoch 2/30
95/95 [==============================] - 0s 412us/step - loss: 0.1015 - val_loss: 0.0528
Epoch 3/30
95/95 [==============================] - 0s 353us/step - loss: 0.0532 - val_loss: 0.0183
Epoch 4/30
95/95 [==============================] - 0s 359us/step - loss: 0.0204 - val_loss: 0.0113
Epoch 5/30
95/95 [==============================] - 0s 448us/step - loss: 0.0145 - val_loss: 0.0119
Epoch 6/30
95/95 [==============================] - 0s 507us/step - loss: 0.0140 - val_loss: 0.0114
Epoch 7/30
95/95 [==============================] - 0s 439us/step - loss: 0.0135 - val_loss: 0.0120
Epoch 8/30
95/95 [==============================] - 0s 373us/step - loss: 0.0132 - val_loss: 0.0118
Epoch 9/30
95/95 [==============================] - 0s 454us/step - loss: 0.0129 - val_loss: 0.0127
Epoch 10/30
95/95 [==============================] - 0s 413us/step - loss: 0.0129 - val_loss: 0.0127
Epoch 11/30
95/95 [==============================] - 0s 418us/step - loss: 0.0129 - val_loss: 0.0147
Epoch 12/30
95/95 [==============================] - 0s 369us/step - loss: 0.0139 - val_loss: 0.0145
Epoch 13/30
95/95 [==============================] - 0s 485us/step - loss: 0.0141 - val_loss: 0.0182
Epoch 14/30
95/95 [==============================] - 0s 459us/step - loss: 0.0166 - val_loss: 0.0146
Epoch 15/30
95/95 [==============================] - 0s 549us/step - loss: 0.0138 - val_loss: 0.0168
Epoch 16/30
95/95 [==============================] - 0s 423us/step - loss: 0.0149 - val_loss: 0.0141
Epoch 17/30
95/95 [==============================] - 0s 401us/step - loss: 0.0129 - val_loss: 0.0155
Epoch 18/30
95/95 [==============================] - 0s 383us/step - loss: 0.0134 - val_loss: 0.0141
Epoch 19/30
95/95 [==============================] - 0s 328us/step - loss: 0.0125 - val_loss: 0.0154
Epoch 20/30
95/95 [==============================] - 0s 401us/step - loss: 0.0130 - val_loss: 0.0144
Epoch 21/30
95/95 [==============================] - 0s 338us/step - loss: 0.0124 - val_loss: 0.0158
Epoch 22/30
95/95 [==============================] - 0s 359us/step - loss: 0.0131 - val_loss: 0.0148
Epoch 23/30
95/95 [==============================] - 0s 338us/step - loss: 0.0126 - val_loss: 0.0164
Epoch 24/30
95/95 [==============================] - 0s 380us/step - loss: 0.0135 - val_loss: 0.0150
Epoch 25/30
95/95 [==============================] - 0s 378us/step - loss: 0.0127 - val_loss: 0.0167
Epoch 26/30
95/95 [==============================] - 0s 541us/step - loss: 0.0137 - val_loss: 0.0151
Epoch 27/30
95/95 [==============================] - 0s 528us/step - loss: 0.0127 - val_loss: 0.0166
Epoch 28/30
95/95 [==============================] - 0s 423us/step - loss: 0.0134 - val_loss: 0.0150
Epoch 29/30
95/95 [==============================] - 0s 515us/step - loss: 0.0125 - val_loss: 0.0164
Epoch 30/30
95/95 [==============================] - 0s 457us/step - loss: 0.0131 - val_loss: 0.0150
[0.6991743  0.4155811  0.43763575 0.1943914  0.24489456 0.43544254
 0.728908   0.27704275 0.7644203  0.24740852 0.58411294 0.33986062
 0.28997922 0.13274276 0.74714196 0.5237809  0.36774576 0.5282971
 0.23951268 0.6239692  0.15398878 0.4958876  0.10568523 0.55706674
 0.32880494 0.60746497 0.294434  ]
[[1.        ]
 [0.25675676]
 [0.4034749 ]
 [0.11969112]
 [0.17374517]
 [0.58108108]
 [0.4980695 ]
 [0.25675676]
 [0.55405405]
 [0.17760618]
 [0.5       ]
 [0.31853282]
 [0.2992278 ]
 [0.01930502]
 [0.58108108]
 [0.48648649]
 [0.4015444 ]
 [0.38030888]
 [0.13127413]
 [0.61003861]
 [0.18339768]
 [0.38996139]
 [0.12741313]
 [0.63899614]
 [0.40733591]
 [0.87837838]
 [0.20656371]]
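
To put a number on the fit shown in the plots, the inverse-transformed predictions can be scored against the true counts. A minimal sketch (assuming scikit-learn is installed and the predict_y and test_y arrays from the main block above are in scope):

import numpy as np
from sklearn.metrics import mean_squared_error

# RMSE in the original units (passengers per month); lower is better.
rmse = np.sqrt(mean_squared_error(test_y, predict_y))
print('Test RMSE: %.2f passengers' % rmse)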

Reposted from blog.csdn.net/duan_zhihua/article/details/84492222