Official Example Explained 4.20 (lstm_stateful.py) - Keras Study Notes, Part 4

This example shows how to use a stateful LSTM model and how its stateless counterpart performs.

 

Keras examples index

Annotated code

'''Example script showing how to use a stateful LSTM model
and how its stateless counterpart performs.

More documentation about the Keras LSTM model can be found at
https://keras.io/layers/recurrent/#lstm

The models are trained on an input/output pair, where
the input is a generated uniformly distributed
random sequence of length = "input_len",
and the output is a moving average of the input with window length = "tsteps".
Both "input_len" and "tsteps" are defined in the "editable parameters" section.

A larger "tsteps" value means that the LSTM will need more memory
to figure out the input-output relationship.
This memory length is controlled by the "lahead" variable (more details below).

The rest of the parameters are:
- input_len: the length of the generated input sequence
- lahead: the input sequence length that the LSTM
  is trained on for each output point
- batch_size, epochs: same parameters as in the model.fit(...) function

When lahead > 1, the model input is preprocessed to a "rolling window view"
of the data, with the window length = "lahead".
This is similar to scikit-image's "view_as_windows"
with "window_shape" being a single number
Ref: http://scikit-image.org/docs/0.10.x/api/skimage.util.html#view-as-windows

When lahead < tsteps, only the stateful LSTM converges because its
statefulness allows it to see beyond the capability that lahead
gave it to fit the n-point average. The stateless LSTM does not have
this capability, and hence is limited by its "lahead" parameter,
which is not sufficient to see the n-point average.

When lahead >= tsteps, both the stateful and stateless LSTM converge.
'''
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense, LSTM

# ----------------------------------------------------------
# EDITABLE PARAMETERS
# Read the documentation in the script head for more details
# ----------------------------------------------------------

# length of input
input_len = 1000

# The window length of the moving average used to generate
# the output from the input in the input/output pair used
# to train the LSTM
# e.g. if tsteps=2 and input=[1, 2, 3, 4, 5],
#      then output=[1.5, 2.5, 3.5, 4.5]
tsteps = 2

# The input sequence length that the LSTM is trained on for each output point
lahead = 1
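# NOTE (added annotation): with the defaults lahead=1 < tsteps=2, only the
# stateful model is expected to converge, which is what the banner printed at
# the start of the main program announces.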

# training parameters passed to "model.fit(...)"
batch_size = 1
epochs = 10

# ------------
# MAIN PROGRAM
# ------------

print("*" * 33)
if lahead >= tsteps:
    print("STATELESS LSTM WILL ALSO CONVERGE")
else:
    print("STATELESS LSTM WILL NOT CONVERGE")
print("*" * 33)

np.random.seed(1986)

print('Generating Data...')


def gen_uniform_amp(amp=1, xn=10000):
    """Generates uniform random data between
    -amp and +amp
    and of length xn

    Arguments:
        amp: maximum/minimum range of uniform data
        xn: length of series
    """
    data_input = np.random.uniform(-1 * amp, +1 * amp, xn)
    data_input = pd.DataFrame(data_input)
    return data_input

# Since the output is a moving average of the input,
# the first few points of output will be NaN
# and will be dropped from the generated data
# before training the LSTM.
# (NaN, short for "Not a Number", is the special value defined by the IEEE 754
# floating-point standard; pandas uses it here to mark the missing leading points.)
# Also, when lahead > 1,
# the preprocessing step later of "rolling window view"
# will also cause some points to be lost.
# For aesthetic reasons,
# in order to maintain generated data length = input_len after pre-processing,
# add a few points to account for the values that will be lost.
to_drop = max(tsteps - 1, lahead - 1)
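# with the defaults tsteps=2, lahead=1 this is max(1, 0) = 1, so one extra
# point is generated and dropped again after preprocessing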
data_input = gen_uniform_amp(amp=0.1, xn=input_len + to_drop)

# set the target to be a N-point average of the input
expected_output = data_input.rolling(window=tsteps, center=False).mean()
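# (for tsteps=2 and an input column [1, 2, 3, 4, 5] this gives
#  [NaN, 1.5, 2.5, 3.5, 4.5]; the leading NaN is dropped further below)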

# when lahead > 1, need to convert the input to a "rolling window view"
# (see the standalone sketch after this listing for a worked lahead=3 example)
# https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html
if lahead > 1:
    data_input = np.repeat(data_input.values, repeats=lahead, axis=1)
    data_input = pd.DataFrame(data_input)
    for i, c in enumerate(data_input.columns):
        data_input[c] = data_input[c].shift(i)

# drop the nan
expected_output = expected_output[to_drop:]
data_input = data_input[to_drop:]

print('Input shape:', data_input.shape)
print('Output shape:', expected_output.shape)
print('Input head: ')
print(data_input.head())
print('Output head: ')
print(expected_output.head())
print('Input tail: ')
print(data_input.tail())
print('Output tail: ')
print(expected_output.tail())

print('Plotting input and expected output')
plt.plot(data_input[0][:10], '.')
plt.plot(expected_output[0][:10], '-')
plt.legend(['Input', 'Expected output'])
plt.title('Input')
plt.show()


def create_model(stateful):
    model = Sequential()
    model.add(LSTM(20,
              input_shape=(lahead, 1),
              batch_size=batch_size,
              stateful=stateful))
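    # a stateful RNN layer must know its batch size when it is built, so that
    # one state vector can be kept per sample slot between batches; that is why
    # batch_size is passed to the layer together with input_shape=(lahead, 1),
    # i.e. (timesteps, features)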
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')
    return model

print('Creating Stateful Model...')
model_stateful = create_model(stateful=True)


# split train/test data
def split_data(x, y, ratio=0.8):
    to_train = int(input_len * ratio)
    # tweak to match with batch_size
    to_train -= to_train % batch_size

    x_train = x[:to_train]
    y_train = y[:to_train]
    x_test = x[to_train:]
    y_test = y[to_train:]

    # tweak to match with batch_size
    to_drop = x.shape[0] % batch_size
    if to_drop > 0:
        x_test = x_test[:-1 * to_drop]
        y_test = y_test[:-1 * to_drop]

    # reshape the inputs to the 3D shape (samples, timesteps, features) that the
    # LSTM layer expects, and the targets to 2D (samples, 1)
    reshape_3 = lambda x: x.values.reshape((x.shape[0], x.shape[1], 1))
    x_train = reshape_3(x_train)
    x_test = reshape_3(x_test)

    reshape_2 = lambda x: x.values.reshape((x.shape[0], 1))
    y_train = reshape_2(y_train)
    y_test = reshape_2(y_test)

    return (x_train, y_train), (x_test, y_test)


(x_train, y_train), (x_test, y_test) = split_data(data_input, expected_output)
print('x_train.shape: ', x_train.shape)
print('y_train.shape: ', y_train.shape)
print('x_test.shape: ', x_test.shape)
print('y_test.shape: ', y_test.shape)

print('Training')
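# The stateful model is trained with a manual loop over epochs: fit() is called
# with epochs=1 and shuffle=False so that the LSTM state flows across batches
# within one pass over the data, and reset_states() then clears the carried
# state before the next epoch starts.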
for i in range(epochs):
    print('Epoch', i + 1, '/', epochs)
    # Note that the last state for sample i in a batch will
    # be used as initial state for sample i in the next batch.
    # Thus we are simultaneously training on batch_size series with
    # lower resolution than the original series contained in data_input.
    # Each of these series are offset by one step and can be
    # extracted with data_input[i::batch_size].
    model_stateful.fit(x_train,
                       y_train,
                       batch_size=batch_size,
                       epochs=1,
                       verbose=1,
                       validation_data=(x_test, y_test),
                       shuffle=False)
    model_stateful.reset_states()

print('Predicting')
predicted_stateful = model_stateful.predict(x_test, batch_size=batch_size)

print('Creating Stateless Model...')
model_stateless = create_model(stateful=False)

print('Training')
model_stateless.fit(x_train,
                    y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test),
                    shuffle=False)

print('Predicting')
predicted_stateless = model_stateless.predict(x_test, batch_size=batch_size)

# ----------------------------

print('Plotting Results')
plt.subplot(3, 1, 1)
plt.plot(y_test)
plt.title('Expected')
plt.subplot(3, 1, 2)
# drop the first "tsteps-1" because it is not possible to predict them
# since the "previous" timesteps to use do not exist
plt.plot((y_test - predicted_stateful).flatten()[tsteps - 1:])
plt.title('Stateful: Expected - Predicted')
plt.subplot(3, 1, 3)
plt.plot((y_test - predicted_stateless).flatten())
plt.title('Stateless: Expected - Predicted')
plt.show()
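
The one preprocessing step in the listing above that is easy to misread is the "rolling window view" applied when lahead > 1 (the default run uses lahead = 1, so that branch is skipped). The following minimal, standalone sketch is not part of the official example; it applies the same repeat-and-shift transform to a toy series so the result can be inspected directly:

import numpy as np
import pandas as pd

lahead = 3
series = pd.DataFrame(np.arange(1.0, 8.0))  # values 1..7 in a single column

# same transform as in lstm_stateful.py: duplicate the column "lahead" times,
# then shift column i down by i steps
windowed = pd.DataFrame(np.repeat(series.values, repeats=lahead, axis=1))
for i, c in enumerate(windowed.columns):
    windowed[c] = windowed[c].shift(i)

print(windowed.dropna())
# each remaining row r holds [x_r, x_{r-1}, x_{r-2}], i.e. the last "lahead"
# inputs the LSTM sees when predicting the output for step r:
#      0    1    2
# 2  3.0  2.0  1.0
# 3  4.0  3.0  2.0
# 4  5.0  4.0  3.0
# 5  6.0  5.0  4.0
# 6  7.0  6.0  5.0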

Running the code

C:\ProgramData\Anaconda3\python.exe E:/keras-master/examples/lstm_stateful.py
Using TensorFlow backend.
*********************************
STATELESS LSTM WILL NOT CONVERGE
*********************************
Generating Data...
Input shape: (1000, 1)
Output shape: (1000, 1)
Input head:
          0
1 -0.084532
2  0.021696
3  0.079500
4  0.008981
5  0.040544
Output head:
          0
1 -0.035379
2 -0.031418
3  0.050598
4  0.044240
5  0.024763
Input tail:
             0
996   0.010251
997  -0.027833
998   0.003984
999   0.028471
1000 -0.057877
Output tail:
             0
996   0.025187
997  -0.008791
998  -0.011925
999   0.016227
1000 -0.014703
Plotting input and expected output
Creating Stateful Model...
x_train.shape:  (800, 1, 1)
y_train.shape:  (800, 1)
x_test.shape:  (200, 1, 1)
y_test.shape:  (200, 1)
Training
Epoch 1 / 10
Train on 800 samples, validate on 200 samples
Epoch 1/1

  1/800 [..............................] - ETA: 23:10 - loss: 8.5216e-04
  8/800 [..............................] - ETA: 2:57 - loss: 9.8957e-04
 13/800 [..............................] - ETA: 1:51 - loss: 8.4976e-04
 20/800 [..............................] - ETA: 1:13 - loss: 7.5000e-04
 29/800 [>.............................] - ETA: 52s - loss: 8.3234e-04
 35/800 [>.............................] - ETA: 43s - loss: 9.5878e-04
 42/800 [>.............................] - ETA: 37s - loss: 0.0010
 49/800 [>.............................] - ETA: 32s - loss: 0.0010
 57/800 [=>............................] - ETA: 28s - loss: 9.2958e-04
 64/800 [=>............................] - ETA: 25s - loss: 0.0010
 72/800 [=>............................] - ETA: 22s - loss: 0.0010
 80/800 [==>...........................] - ETA: 20s - loss: 0.0010
 88/800 [==>...........................] - ETA: 19s - loss: 9.7788e-04
 93/800 [==>...........................] - ETA: 18s - loss: 9.7039e-04
 98/800 [==>...........................] - ETA: 17s - loss: 0.0010
107/800 [===>..........................] - ETA: 16s - loss: 9.8123e-04
115/800 [===>..........................] - ETA: 15s - loss: 9.6177e-04
122/800 [===>..........................] - ETA: 14s - loss: 9.4016e-04
130/800 [===>..........................] - ETA: 13s - loss: 9.4474e-04
138/800 [====>.........................] - ETA: 13s - loss: 9.5132e-04
146/800 [====>.........................] - ETA: 12s - loss: 9.3231e-04
154/800 [====>.........................] - ETA: 12s - loss: 9.3191e-04
160/800 [=====>........................] - ETA: 11s - loss: 9.1047e-04
166/800 [=====>........................] - ETA: 11s - loss: 8.9187e-04
172/800 [=====>........................] - ETA: 11s - loss: 8.9586e-04
177/800 [=====>........................] - ETA: 10s - loss: 9.0390e-04
183/800 [=====>........................] - ETA: 10s - loss: 8.8498e-04
189/800 [======>.......................] - ETA: 10s - loss: 9.0886e-04
194/800 [======>.......................] - ETA: 10s - loss: 8.9147e-04
200/800 [======>.......................] - ETA: 9s - loss: 9.0357e-04
206/800 [======>.......................] - ETA: 9s - loss: 8.8145e-04
213/800 [======>.......................] - ETA: 9s - loss: 9.0894e-04
221/800 [=======>......................] - ETA: 9s - loss: 9.0075e-04
229/800 [=======>......................] - ETA: 8s - loss: 9.1960e-04
237/800 [=======>......................] - ETA: 8s - loss: 9.1508e-04
244/800 [========>.....................] - ETA: 8s - loss: 8.9347e-04
249/800 [========>.....................] - ETA: 8s - loss: 8.8506e-04
256/800 [========>.....................] - ETA: 7s - loss: 8.8218e-04
264/800 [========>.....................] - ETA: 7s - loss: 8.6197e-04
270/800 [=========>....................] - ETA: 7s - loss: 8.6244e-04
277/800 [=========>....................] - ETA: 7s - loss: 8.5799e-04
285/800 [=========>....................] - ETA: 7s - loss: 8.4244e-04
293/800 [=========>....................] - ETA: 6s - loss: 8.2892e-04
301/800 [==========>...................] - ETA: 6s - loss: 8.1710e-04
308/800 [==========>...................] - ETA: 6s - loss: 8.1387e-04
316/800 [==========>...................] - ETA: 6s - loss: 8.0678e-04
323/800 [===========>..................] - ETA: 6s - loss: 8.0260e-04
332/800 [===========>..................] - ETA: 6s - loss: 8.0745e-04
336/800 [===========>..................] - ETA: 6s - loss: 8.0657e-04
345/800 [===========>..................] - ETA: 5s - loss: 7.9895e-04
354/800 [============>.................] - ETA: 5s - loss: 7.8875e-04
362/800 [============>.................] - ETA: 5s - loss: 7.8790e-04
370/800 [============>.................] - ETA: 5s - loss: 7.9508e-04
378/800 [=============>................] - ETA: 5s - loss: 7.8973e-04
386/800 [=============>................] - ETA: 5s - loss: 7.8466e-04
394/800 [=============>................] - ETA: 4s - loss: 7.9312e-04
402/800 [==============>...............] - ETA: 4s - loss: 7.8996e-04
411/800 [==============>...............] - ETA: 4s - loss: 7.7805e-04
420/800 [==============>...............] - ETA: 4s - loss: 7.6865e-04
429/800 [===============>..............] - ETA: 4s - loss: 7.6231e-04
437/800 [===============>..............] - ETA: 4s - loss: 7.5908e-04
445/800 [===============>..............] - ETA: 4s - loss: 7.4860e-04
452/800 [===============>..............] - ETA: 3s - loss: 7.4074e-04
460/800 [================>.............] - ETA: 3s - loss: 7.3031e-04
468/800 [================>.............] - ETA: 3s - loss: 7.2087e-04
476/800 [================>.............] - ETA: 3s - loss: 7.2082e-04
483/800 [=================>............] - ETA: 3s - loss: 7.2121e-04
491/800 [=================>............] - ETA: 3s - loss: 7.1069e-04
499/800 [=================>............] - ETA: 3s - loss: 7.0494e-04
507/800 [==================>...........] - ETA: 3s - loss: 7.0049e-04
515/800 [==================>...........] - ETA: 3s - loss: 6.9607e-04
523/800 [==================>...........] - ETA: 2s - loss: 6.9888e-04
531/800 [==================>...........] - ETA: 2s - loss: 6.9627e-04
539/800 [===================>..........] - ETA: 2s - loss: 7.0178e-04
545/800 [===================>..........] - ETA: 2s - loss: 6.9792e-04
551/800 [===================>..........] - ETA: 2s - loss: 6.9527e-04
556/800 [===================>..........] - ETA: 2s - loss: 6.9119e-04
562/800 [====================>.........] - ETA: 2s - loss: 6.8701e-04
568/800 [====================>.........] - ETA: 2s - loss: 6.8762e-04
575/800 [====================>.........] - ETA: 2s - loss: 6.8512e-04
583/800 [====================>.........] - ETA: 2s - loss: 6.8444e-04
590/800 [=====================>........] - ETA: 2s - loss: 6.8151e-04
595/800 [=====================>........] - ETA: 2s - loss: 6.8386e-04
599/800 [=====================>........] - ETA: 2s - loss: 6.8208e-04
605/800 [=====================>........] - ETA: 2s - loss: 6.7642e-04
610/800 [=====================>........] - ETA: 1s - loss: 6.7531e-04
616/800 [======================>.......] - ETA: 1s - loss: 6.7193e-04
622/800 [======================>.......] - ETA: 1s - loss: 6.6764e-04
628/800 [======================>.......] - ETA: 1s - loss: 6.6446e-04
635/800 [======================>.......] - ETA: 1s - loss: 6.6137e-04
642/800 [=======================>......] - ETA: 1s - loss: 6.5629e-04
649/800 [=======================>......] - ETA: 1s - loss: 6.5130e-04
656/800 [=======================>......] - ETA: 1s - loss: 6.4967e-04
664/800 [=======================>......] - ETA: 1s - loss: 6.4691e-04
672/800 [========================>.....] - ETA: 1s - loss: 6.4803e-04
680/800 [========================>.....] - ETA: 1s - loss: 6.4319e-04
688/800 [========================>.....] - ETA: 1s - loss: 6.4213e-04
696/800 [=========================>....] - ETA: 1s - loss: 6.3881e-04
704/800 [=========================>....] - ETA: 0s - loss: 6.3280e-04
713/800 [=========================>....] - ETA: 0s - loss: 6.3063e-04
721/800 [==========================>...] - ETA: 0s - loss: 6.2775e-04
727/800 [==========================>...] - ETA: 0s - loss: 6.2472e-04
732/800 [==========================>...] - ETA: 0s - loss: 6.2265e-04
739/800 [==========================>...] - ETA: 0s - loss: 6.1951e-04
745/800 [==========================>...] - ETA: 0s - loss: 6.1688e-04
750/800 [===========================>..] - ETA: 0s - loss: 6.1430e-04
757/800 [===========================>..] - ETA: 0s - loss: 6.1158e-04
764/800 [===========================>..] - ETA: 0s - loss: 6.0985e-04
772/800 [===========================>..] - ETA: 0s - loss: 6.0831e-04
780/800 [============================>.] - ETA: 0s - loss: 6.0951e-04
788/800 [============================>.] - ETA: 0s - loss: 6.0644e-04
795/800 [============================>.] - ETA: 0s - loss: 6.0409e-04
800/800 [==============================] - 8s 10ms/step - loss: 6.0417e-04 - val_loss: 5.2337e-04
Epoch 2 / 10
Train on 800 samples, validate on 200 samples
Epoch 1/1

  1/800 [..............................] - ETA: 4s - loss: 9.3058e-05
  8/800 [..............................] - ETA: 5s - loss: 3.6670e-04

785/800 [============================>.] - ETA: 0s - loss: 8.4495e-04
794/800 [============================>.] - ETA: 0s - loss: 8.4709e-04
800/800 [==============================] - 6s 8ms/step - loss: 8.5014e-04 - val_loss: 9.6398e-04
Predicting
Plotting Results

Process finished with exit code 0

More about Keras

English: https://keras.io/

Chinese: http://keras-cn.readthedocs.io/en/latest/

Example downloads

https://github.com/keras-team/keras

https://github.com/keras-team/keras/tree/master/examples

Complete project download

Readers without download points can join QQ group 452205574 to access the shared folder.

It includes: code, datasets (images), trained models, library installation files, etc.



Reposted from blog.csdn.net/wyx100/article/details/80821721