(tensorflow)使用循环神经网络模型预测正弦函数

一:代码

# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
import matplotlib as mpl
mpl.use('Agg') #设置只保存绘制图片,不以窗口形式显示
from matplotlib import pyplot as plt

from tensorflow.contrib import rnn

HIDDEN_SIZE = 30
NUM_LAYERS = 2
TIME_STEPS = 10 #循环神经网络的截断长度
LEARNING_RATE = 1.0
BATCH_SIZE = 32
TRAIN_STEPS = 10000

TRAINING_EXAMPLES = 10000
TESTING_EXAMPLES = 1000
SAMPLE_GAP = 0.01      #采样间隔

learn = tf.contrib.learn

def generate_data(seq):
    X = []
    y = []
    #序列的第i项和后面的TIME_STEPS-1项合在一起作为输入;第i+TIME_STEPS项作为输出。
    #即用sin函数前面的TIME_STEPS个点信息,预测第i+TIME_STEPS个点的函数值
    for i in range(len(seq) - TIME_STEPS - 1):
        X.append([seq[i:i+TIME_STEPS]])
        y.append([seq[i+TIME_STEPS]])
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)


def lstm_model(X, y):
    stacked_rnn = []
    #使用多层LSTMCell结构
    for iiLyr in range(3):
        stacked_rnn.append(tf.nn.rnn_cell.LSTMCell(num_units=HIDDEN_SIZE, state_is_tuple=True))

    cell = tf.nn.rnn_cell.MultiRNNCell(cells=stacked_rnn, state_is_tuple=True)
    x_ = tf.unstack(X, axis=1)
    # 将多层LSTMCell结构连接成RNN网络并计算前向传播结果
    output, _ = rnn.static_rnn(cell, x_, dtype=tf.float32)
    #这里只关注最后一个时刻的输出结果,该结果为下一时刻的预测值
    output = output[-1]

    prediction, loss = learn.models.linear_regression(output, y)
    train_op = tf.contrib.layers.optimize_loss(
        loss,
        tf.contrib.framework.get_global_step(),
        optimizer='Adagrad', learning_rate=0.1
    )

    return prediction, loss, train_op

regressor = learn.Estimator(model_fn=lstm_model)

test_start = TRAINING_EXAMPLES * SAMPLE_GAP
test_end = (TRAINING_EXAMPLES + TESTING_EXAMPLES) * SAMPLE_GAP

#用正弦函数生成训练,预测数据集合
train_x, train_y = generate_data(np.sin(np.linspace(0, test_start, TRAINING_EXAMPLES, dtype=np.float32)))
test_x, test_y = generate_data(np.sin(np.linspace(test_start, test_end, TESTING_EXAMPLES, dtype=np.float32)))

regressor.fit(train_x, train_y, batch_size=BATCH_SIZE, steps=TRAIN_STEPS)

predicted = [[pred] for pred in regressor.predict(test_x)]

#使用rmse作为预测指标
rmse = np.sqrt(((predicted - test_y) ** 2).mean(axis=0))
print('mean square error is: %f' % rmse[0])

#对预测的sin函数曲线进行绘图
fig = plt.figure()
plot_predicted = plt.plot(predicted, label='predicted')
plot_test = plt.plot(test_y, label='real_sin')
plt.legend()
fig.savefig('sin.png')

二:结果

mean square error is: 0.002049

这里写图片描述

图中可以看出,预测的sin曲线与正弦曲线基本重合

猜你喜欢

转载自blog.csdn.net/zzldm/article/details/82491700