Video Teaching: Traffic Flow Prediction LSTM Practical Detailed Teaching_哔哩哔哩_bilibili
The result display:
Full code:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch import nn
from torch.autograd import Variable
data_csv = pd.read_csv("上海中山公园地铁客流2015年数据.csv",encoding = 'gb2312')
print(data_csv.head())
plt.figure(figsize=(12,4))
plt.plot(data_csv[0:300])
plt.show()
# 创建训练和测试LSTM模型的数据集,是通过前面测试的15min时间粒度的客流量来预测当前时间粒度的客流量,我们令前2个时间粒度的客流数据是输入,对应代码中的step=2,
# 把当前时间粒度的客流数据作为输出,划分数据