This example shows how to use a stateful LSTM model in Keras, and how its stateless counterpart performs on the same task.
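Before diving into the full script, here is a minimal sketch (illustrative only, not part of the original script, though it uses the same Keras 2 API) of what the stateful switch means: with stateful=True the LSTM's hidden state persists across batches until reset_states() is called, whereas a stateless LSTM resets its state after every batch.

# Minimal illustration of the stateful switch. With stateful=True the
# batch size must be fixed at model-definition time, and the hidden state
# carries over from one batch to the next until reset_states() is called.
from keras.models import Sequential
from keras.layers import Dense, LSTM

model = Sequential()
model.add(LSTM(20,
               batch_size=1,         # required for stateful layers
               input_shape=(1, 1),   # (timesteps, features)
               stateful=True))       # keep state across batches
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')

# Typical usage: train one epoch at a time without shuffling, and reset
# the state between passes over the sequence:
# model.fit(x, y, epochs=1, batch_size=1, shuffle=False)
# model.reset_states()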
Annotated code
'''Example script showing how to use a stateful LSTM model
and how its stateless counterpart performs.

More documentation about the Keras LSTM model can be found at
https://keras.io/layers/recurrent/#lstm

The models are trained on an input/output pair, where
the input is a generated uniformly distributed
random sequence of length = "input_len",
and the output is a moving average of the input with window length = "tsteps".
Both "input_len" and "tsteps" are defined in the "editable parameters"
section.

A larger "tsteps" value means that the LSTM will need more memory
to figure out the input-output relationship.
This memory length is controlled by the "lahead" variable (more details below).

The rest of the parameters are:

- input_len: the length of the generated input sequence
- lahead: the input sequence length that the LSTM
  is trained on for each output point
- batch_size, epochs: same parameters as in the model.fit(...) function

When lahead > 1, the model input is preprocessed to a "rolling window view"
of the data, with the window length = "lahead".
This is similar to sklearn's "view_as_windows"
with "window_shape" being a single number.
Ref: http://scikit-image.org/docs/0.10.x/api/skimage.util.html#view-as-windows

When lahead < tsteps, only the stateful LSTM converges because its
statefulness allows it to see beyond the capability that lahead
gave it to fit the n-point average. The stateless LSTM does not have
this capability, and hence is limited by its "lahead" parameter,
which is not sufficient to see the n-point average.

When lahead >= tsteps, both the stateful and stateless LSTM converge.
'''
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense, LSTM

# ----------------------------------------------------------
# EDITABLE PARAMETERS
# Read the documentation in the script head for more details
# ----------------------------------------------------------

# length of input
input_len = 1000

# The window length of the moving average used to generate
# the output from the input in the input/output pair used
# to train the LSTM
# e.g. if tsteps=2 and input=[1, 2, 3, 4, 5],
#      then output=[1.5, 2.5, 3.5, 4.5]
tsteps = 2

# The input sequence length that the LSTM is trained on for each output point
lahead = 1

# training parameters passed to "model.fit(...)"
batch_size = 1
epochs = 10

# ------------
# MAIN PROGRAM
# ------------

print("*" * 33)
if lahead >= tsteps:
    print("STATELESS LSTM WILL ALSO CONVERGE")
else:
    print("STATELESS LSTM WILL NOT CONVERGE")
print("*" * 33)

np.random.seed(1986)

print('Generating Data...')


def gen_uniform_amp(amp=1, xn=10000):
    """Generates uniform random data between
    -amp and +amp
    and of length xn

    # Arguments
        amp: maximum/minimum range of uniform data
        xn: length of series
    """
    data_input = np.random.uniform(-1 * amp, +1 * amp, xn)
    data_input = pd.DataFrame(data_input)
    return data_input

# Since the output is a moving average of the input,
# the first few points of output will be NaN
# and will be dropped from the generated data
# before training the LSTM.
# (NaN, short for "Not a Number", is a special value defined by the
# IEEE 754 floating-point standard.)
# Also, when lahead > 1,
# the preprocessing step later of "rolling window view"
# will also cause some points to be lost.
# For aesthetic reasons,
# in order to maintain generated data length = input_len after pre-processing,
# add a few points to account for the values that will be lost.
to_drop = max(tsteps - 1, lahead - 1)
data_input = gen_uniform_amp(amp=0.1, xn=input_len + to_drop)

# set the target to be a N-point average of the input
expected_output = data_input.rolling(window=tsteps, center=False).mean()

# when lahead > 1, need to convert the input to "rolling window view"
# https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html
if lahead > 1:
    data_input = np.repeat(data_input.values, repeats=lahead, axis=1)
    data_input = pd.DataFrame(data_input)
    for i, c in enumerate(data_input.columns):
        data_input[c] = data_input[c].shift(i)

# drop the NaN
expected_output = expected_output[to_drop:]
data_input = data_input[to_drop:]

print('Input shape:', data_input.shape)
print('Output shape:', expected_output.shape)
print('Input head: ')
print(data_input.head())
print('Output head: ')
print(expected_output.head())
print('Input tail: ')
print(data_input.tail())
print('Output tail: ')
print(expected_output.tail())

print('Plotting input and expected output')
plt.plot(data_input[0][:10], '.')
plt.plot(expected_output[0][:10], '-')
plt.legend(['Input', 'Expected output'])
plt.title('Input')
plt.show()


def create_model(stateful):
    model = Sequential()
    model.add(LSTM(20,
                   input_shape=(lahead, 1),
                   batch_size=batch_size,
                   stateful=stateful))
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')
    return model

print('Creating Stateful Model...')
model_stateful = create_model(stateful=True)


# split train/test data
def split_data(x, y, ratio=0.8):
    to_train = int(input_len * ratio)
    # tweak to match with batch_size
    to_train -= to_train % batch_size

    x_train = x[:to_train]
    y_train = y[:to_train]
    x_test = x[to_train:]
    y_test = y[to_train:]

    # tweak to match with batch_size
    to_drop = x.shape[0] % batch_size
    if to_drop > 0:
        x_test = x_test[:-1 * to_drop]
        y_test = y_test[:-1 * to_drop]

    # some reshaping
    reshape_3 = lambda x: x.values.reshape((x.shape[0], x.shape[1], 1))
    x_train = reshape_3(x_train)
    x_test = reshape_3(x_test)

    reshape_2 = lambda x: x.values.reshape((x.shape[0], 1))
    y_train = reshape_2(y_train)
    y_test = reshape_2(y_test)

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = split_data(data_input, expected_output)
print('x_train.shape: ', x_train.shape)
print('y_train.shape: ', y_train.shape)
print('x_test.shape: ', x_test.shape)
print('y_test.shape: ', y_test.shape)

print('Training')
for i in range(epochs):
    print('Epoch', i + 1, '/', epochs)
    # Note that the last state for sample i in a batch will
    # be used as initial state for sample i in the next batch.
    # Thus we are simultaneously training on batch_size series with
    # lower resolution than the original series contained in data_input.
    # Each of these series are offset by one step and can be
    # extracted with data_input[i::batch_size].
    model_stateful.fit(x_train,
                       y_train,
                       batch_size=batch_size,
                       epochs=1,
                       verbose=1,
                       validation_data=(x_test, y_test),
                       shuffle=False)
    model_stateful.reset_states()

print('Predicting')
predicted_stateful = model_stateful.predict(x_test, batch_size=batch_size)

print('Creating Stateless Model...')
model_stateless = create_model(stateful=False)

print('Training')
model_stateless.fit(x_train,
                    y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test),
                    shuffle=False)

print('Predicting')
predicted_stateless = model_stateless.predict(x_test, batch_size=batch_size)

# ----------------------------

print('Plotting Results')
plt.subplot(3, 1, 1)
plt.plot(y_test)
plt.title('Expected')
plt.subplot(3, 1, 2)
# drop the first "tsteps-1" because it is not possible to predict them
# since the "previous" timesteps to use do not exist
plt.plot((y_test - predicted_stateful).flatten()[tsteps - 1:])
plt.title('Stateful: Expected - Predicted')
plt.subplot(3, 1, 3)
plt.plot((y_test - predicted_stateless).flatten())
plt.title('Stateless: Expected - Predicted')
plt.show()
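The trickiest part of the script is the "rolling window view" preprocessing used when lahead > 1 (np.repeat followed by a per-column shift), together with the rolling-mean target. The following standalone sketch replays both steps on a hypothetical 5-point toy series; the names `toy` and `windowed` are illustrative, not from the script.

# A minimal, self-contained sketch of the two preprocessing steps:
# the moving-average target and the "rolling window view".
import numpy as np
import pandas as pd

toy = pd.DataFrame([1.0, 2.0, 3.0, 4.0, 5.0])

# Moving-average target with tsteps = 2: row 0 becomes NaN (no previous
# point to average with) and is later dropped, exactly as in the script.
target = toy.rolling(window=2, center=False).mean()
print(target.values.ravel())  # [nan 1.5 2.5 3.5 4.5]

# Rolling window view with lahead = 3: repeat the column lahead times,
# then shift column i down by i rows, so each surviving row holds the
# current value followed by the lahead-1 previous values.
lahead = 3
windowed = np.repeat(toy.values, repeats=lahead, axis=1)
windowed = pd.DataFrame(windowed)
for i, c in enumerate(windowed.columns):
    windowed[c] = windowed[c].shift(i)
print(windowed.dropna().values)
# [[3. 2. 1.]
#  [4. 3. 2.]
#  [5. 4. 3.]]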
Execution output
C:\ProgramData\Anaconda3\python.exe E:/keras-master/examples/lstm_stateful.py
Using TensorFlow backend.
*********************************
STATELESS LSTM WILL NOT CONVERGE
*********************************
Generating Data...
Input shape: (1000, 1)
Output shape: (1000, 1)
Input head: 
          0
1 -0.084532
2  0.021696
3  0.079500
4  0.008981
5  0.040544
Output head: 
          0
1 -0.035379
2 -0.031418
3  0.050598
4  0.044240
5  0.024763
Input tail: 
             0
996   0.010251
997  -0.027833
998   0.003984
999   0.028471
1000 -0.057877
Output tail: 
             0
996   0.025187
997  -0.008791
998  -0.011925
999   0.016227
1000 -0.014703
Plotting input and expected output
Creating Stateful Model...
x_train.shape:  (800, 1, 1)
y_train.shape:  (800, 1)
x_test.shape:  (200, 1, 1)
y_test.shape:  (200, 1)
Training
Epoch 1 / 10
Train on 800 samples, validate on 200 samples
Epoch 1/1
  1/800 [..............................] - ETA: 23:10 - loss: 8.5216e-04
...
800/800 [==============================] - 8s 10ms/step - loss: 6.0417e-04 - val_loss: 5.2337e-04
Epoch 2 / 10
Train on 800 samples, validate on 200 samples
Epoch 1/1
  1/800 [..............................] - ETA: 4s - loss: 9.3058e-05
...
800/800 [==============================] - 6s 8ms/step - loss: 8.5014e-04 - val_loss: 9.6398e-04
...
Predicting
Plotting Results

Process finished with exit code 0
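The final plots only let you eyeball the convergence gap between the two models. A short follow-up snippet along the following lines could quantify it; this is a sketch that assumes y_test, predicted_stateful, predicted_stateless, and tsteps from the script above are still in scope.

# Compare test-set MSE of the two models' predictions.
import numpy as np

# Skip the first tsteps-1 points, which cannot be predicted because the
# "previous" timesteps they would need do not exist.
mse_stateful = np.mean(
    ((y_test - predicted_stateful).flatten()[tsteps - 1:]) ** 2)
mse_stateless = np.mean(
    ((y_test - predicted_stateless).flatten()[tsteps - 1:]) ** 2)
print('Stateful test MSE: ', mse_stateful)
print('Stateless test MSE:', mse_stateless)
# With lahead < tsteps (as in the run above), the stateful MSE should be
# clearly lower; with lahead >= tsteps both should be comparable.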
More about Keras
Chinese documentation: http://keras-cn.readthedocs.io/en/latest/
Example downloads
https://github.com/keras-team/keras
https://github.com/keras-team/keras/tree/master/examples