【python量化】挖掘股价中的图关系:基于图注意力网络的股价预测模型

f135eb2240577efe3c9f84e77ae42d2c.png

写在前面

近些年,图神经网络在时间序列预测领域发挥了重要的作用。其中,图注意力网络(GAT)是一种基于注意力机制的图神经网络,能够捕捉图结构数据中节点之间的复杂关系,从而在许多领域中取得了突出的性能。在本文中,我们利用Pytorch以及PyG(PyTorch Geometric 是一个强大而灵活的图神经网络库)框架,实现一个将 GAT 应用于股票价格预测的简单例子

1

前言

随着金融市场的复杂性不断增加,对股价的预测成为了一项巨大的挑战。传统的时间序列分析方法虽然在某些场景下有效,但往往无法捕捉市场中的复杂相互作用和隐藏模式。为了解决这一问题,本文采用了一种全新的视角:将股价时间序列转化为图结构,通过图注意力网络(GAT)以图的视角来建模和分析。图注意力网络是一种强大的图神经网络结构,通过引入注意力机制,能够灵活捕捉图中节点间的相互关系。在股价预测的场景中,股票之间的相互作用和依赖关系可以被自然地建模为图结构,其中股票作为节点,它们之间的相互作用作为边。借助 PyTorch Geometric(PyG)这一先进的图神经网络库,本文展示了如何构建、训练和评估 GAT 模型来预测股票价格。

项目的核心是将股价时间序列数据转化为图结构,然后利用 GAT 的强大表征能力进行分析和预测。本项目旨在作为一个简单演示,展示如何使用图注意力网络来分析股价时间序列,仅作为学习和研究之用,不应用于实际的投资决策

2

环境配置

本地环境:

Python 3.7
IDE:Pycharm

库版本:

numpy 1.18.1
pandas 1.0.3 
torch-geometric 2.0.2
matplotlib 3.2.1
torch 1.10.1
tushare 1.2.60

3

代码实现

总体设计

该项目通过四个主要文件组织,展示了如何使用图注意力网络(GAT)进行股票价格预测。从获取和处理股票数据到构建和训练深度学习模型,再到最终的预测阶段,整个流程被清晰地实现。每个文件都专注于特定的任务,共同构建了一个完整的股价预测解决方案:

1. data_hander.py 数据处理模块

此模块负责从 Tushare 获取股票的收盘价并保存到文件中。主要功能包括:获取指定股票代码和日期范围的收盘价数据。保存和加载收盘价数据。绘制股价曲线图。构建股票间的邻接矩阵,用于图模型。

2. model.py - 模型模块

在这个模块中,定义了基于图注意力网络(GAT)的深度学习模型。主要组成部分包括构建GATPredictor 类,用于构建 GAT 模型,设置不同的隐藏层大小和注意力头数量。定义了模型的前向传播过程。

3. train.py - 训练模块

此文件包括了数据预处理和模型训练的相关函数:数据归一化和分割为训练和测试集。使用滑动窗口方法处理时间序列数据。定义训练和预测的主要函数。

4. main.py - 主程序模块

作为项目的入口点,此文件组织和调用上述三个文件中的功能:定义要预测的股票代码和日期范围。调用数据处理、模型构建、训练和预测的相关函数。定义了整个项目的主要流程和执行逻辑。

数据处理模块

数据处理模块负责从外部源获取股票价格数据,并进行适当的预处理以供模型使用。其中,需要将tushare API token替换到代码中的Your Token处。

def fetch_close_prices(stocks, start_date, end_date, file_name='close_prices.csv'):
    if os.path.exists(file_name):
        close_prices = pd.read_csv(file_name).values
    else:
        ts.set_token('Your Token')
        pro = ts.pro_api()
        close_prices_list = []


        for stock in stocks:
            df = pro.daily(ts_code=stock, start_date=start_date, end_date=end_date)
            close_prices_list.append(df['close'].values)


        close_prices = np.column_stack(close_prices_list)
        pd.DataFrame(close_prices).to_csv(file_name, index=False)


    return close_prices

为了使用图注意力网络(GAT),需要将股票数据转化为图结构。这里将每只股票的收盘价的Pearson系数作为构建邻接矩阵来表示股票之间的相关性的依据,并设定阈值来确定是否具有连边。

def build_adjacency_matrix(close_prices, threshold=0.5):
    N = close_prices.shape[1]
    adj_matrix = np.zeros((N, N))


    for i in range(N):
        for j in range(N):
            correlation = np.corrcoef(close_prices[:, i], close_prices[:, j])[0, 1]
            adj_matrix[i, j] = 1 if abs(correlation) > threshold else 0
    return adj_matrix

模型模块

这个模块负责定义和实现一个简单的基于图注意力网络的股价预测模型。通过构建包括图注意力层、隐藏层和输出层的结构,来实现模型的前向传递。其中,由于PyG的GAT层只能接受二维的数据,所以为了提升并行运算速度,这里将一个batch的数据合成一张大图进行运算,最后通过view转换后再通过linear层进行输出预测值。

class GATPredictor(nn.Module):
    def __init__(self, node_features, node_nums, hidden_size=32, num_heads=1):
        super(GATPredictor, self).__init__()
        self.node_nums = node_nums
        self.gat1 = GATConv(node_features, hidden_size, heads=num_heads)
        self.gat2 = GATConv(hidden_size * num_heads, hidden_size, concat=False)
        self.out = nn.Linear(hidden_size, 1)


    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        batch_size = torch.max(batch).item() + 1
        x = torch.relu(self.gat1(x, edge_index))
        x = torch.relu(self.gat2(x, edge_index))
        x = x.view(batch_size, self.node_nums, -1)
        x = self.out(x)
        return x.squeeze(2)

训练模块

训练模块主要负责处理数据到模型的指定输入形式以及模型训练的逻辑。其中实现了一些对数据进行归一化处理以及划分滑动窗口的函数,以及模型训练和测试的函数。

def normalize_and_split(closing_prices, test_size=0.2):
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(closing_prices)
    train_data, test_data = train_test_split(normalized_data, test_size=test_size, shuffle=False)
    return train_data, test_data, scaler


def sliding_window(data, window_size):
    windows = []
    for i in range(len(data) - window_size):
        x = data[i:i + window_size]
        y = data[i + window_size]
        windows.append((x, y))
    return window

构建节点之间的连边时,每个股票作为一个节点,其窗口的数据作为节点特征,并且对每个滑动窗口构建相同的图关系,最后将其转换为PyG的指定输入数据形式。

def create_graph_data(windows, adj_matrix):
    graph_data = []
    edge_index = torch.tensor(np.where(adj_matrix != 0), dtype=torch.long)


    for window in windows:
        x, y = window
        x_tensor = torch.tensor(x.T, dtype=torch.float)
        y_tensor = torch.tensor(y, dtype=torch.float)  


        data = Data(x=x_tensor, y=y_tensor, edge_index=edge_index)
        graph_data.append(data)


    return graph_data


def train(model, train_loader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            y = batch.y.view(out.size())
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
        print(f'Epoch: {epoch+1}, Loss: {loss.item()}'

主程序模块

这个模块根据给定的股票代码和日期范围获取历史收盘价数据,这里简单取了上证2022年的几只股票。然后通过滑动窗口方法和邻接矩阵将其转化为图数据格式。接下来,初始化并训练一个基于图注意力网络的预测模型,最后使用该模型进行预测并将预测结果与真实值进行可视化对比。

# 股票代码和日期范围
stocks = ['000001.SZ', '000002.SZ', '000006.SZ', '000005.SZ', '000008.SZ', '000009.SZ']
start_date = "2022-01-01"
end_date = "2023-01-01"
window_size = 10


# 获取收盘价并创建邻接矩阵
close_prices = fetch_close_prices(stocks, start_date, end_date)
adj_matrix = build_adjacency_matrix(close_prices)


# 数据归一化和划分
train_data, test_data, scaler = normalize_and_split(close_prices)


# 滑动窗口数据准备
train_windows = sliding_window(train_data, window_size)
test_windows = sliding_window(test_data, window_size)


# 转换为图数据
train_dataset = create_graph_data(train_windows, adj_matrix)
test_dataset = create_graph_data(test_windows, adj_matrix)


print(adj_matrix)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


# 创建模型
model = GATPredictor(node_features=window_size, node_nums=len(stocks))


# 优化器和损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()


# 训练
train(model, train_loader, optimizer, criterion, epochs=50)


# 预测
predictions, groundtruth = predict(model, test_loader, scaler)


# 可视化预测结果
plot_results(predictions, groundtruth, stocks)

运行效果

接下来对模型进行训练跟测试,经过多个epoch的迭代,mse逐渐收敛。

Epoch: 1, Loss: 0.3193877637386322
Epoch: 2, Loss: 0.21741190552711487
Epoch: 3, Loss: 0.16308942437171936
Epoch: 4, Loss: 0.09306434541940689
Epoch: 5, Loss: 0.052657563239336014
Epoch: 6, Loss: 0.02670634351670742
Epoch: 7, Loss: 0.019804660230875015
Epoch: 8, Loss: 0.016364283859729767
Epoch: 9, Loss: 0.017054198309779167
Epoch: 10, Loss: 0.01848340593278408
...
Epoch: 44, Loss: 0.009028978645801544
Epoch: 45, Loss: 0.013259928673505783
Epoch: 46, Loss: 0.01599966734647751
Epoch: 47, Loss: 0.011484017595648766
Epoch: 48, Loss: 0.010215471498668194
Epoch: 49, Loss: 0.01579277403652668
Epoch: 50, Loss: 0.011635522358119488

之后对多个股票的预测结果进行可视化,可以看出有些股票的趋势拟合效果较好。由于这六只股票只是简单的选取,所以其关联性可能不会很强,所以也可以选择多个相同板块的股票进行实验。除此之外,模型的结构也很简单,并没有引入更多的股价特征,所以可以进一步改进的点还有很多。

f528cb1297400fff18ebf067d1a7e70a.png

4

总结


在现代金融领域,随着大数据和机器学习技术的日益成熟,传统的股价预测方法正面临着前所未有的挑战和机遇。传统的股价预测往往依赖于单一股票的历史数据,忽略了股票间的互动和影响。这篇文章探讨了如何利用图注意力网络 (GAT) 挖掘股票之间的潜在关系,为股价预测提供了新的视角。其中将邻接矩阵与滑动窗口方法相结合,将时间序列的股价数据转化为图数据格式,从而捕捉股票间的复杂相互作用。本文中的实验只是一个简单的demo,还存在许多可以改进的空间,感兴趣的读者可以进一步研究。

本文内容仅仅是技术探讨和学习,并不构成任何投资建议。

获取完整代码与数据以及其他历史文章完整源码与数据可加入《人工智能量化实验室》知识星球。

往期推荐阅读

WWW 2023 | 量化交易相关论文(附论文链接)

KDD 2023 | 量化交易相关论文(附论文链接)

AAAI 2022 | 量化交易相关论文(附论文链接)

IJCAI 2022 | 量化交易相关论文(附论文链接)

WWW 2022 | 量化交易相关论文(附论文链接)

KDD 2022 | 量化交易相关论文(附论文链接)

解读:ChatGPT在股票市场预测方面的应用

解读:通过挖掘概念间共享信息,实现股票趋势预测的图模型框架

解读:机器学习预测收益模型应该采取哪种度量指标

解读:基于订单流、技术分析与神经网络的期货短期走势预测模型

【python量化】基于backtrader的深度学习模型量化回测框架

【python量化】将Transformer模型用于股票价格预测

【python量化】搭建一个CNN-LSTM模型用于股票价格预测

【python量化】用python搭建一个股票舆情分析系统

【python量化】将Informer用于股价预测

【python量化】将DeepAR用于股票价格多步概率预测

9992a579714fc702da17a79279e8aa84.png

《人工智能量化实验室》知识星球

fd428cefb5057f46318f072a6b0b706b.png

加入人工智能量化实验室知识星球,您可以获得:(1)定期推送最新人工智能量化应用相关的研究成果,包括高水平期刊论文以及券商优质金融工程研究报告,便于您随时随地了解最新前沿知识;(2)公众号历史文章Python项目完整源码;(3)优质Python、机器学习、量化交易相关电子书PDF;(4)优质量化交易资料、项目代码分享;(5)跟星友一起交流,结交志同道合朋友。(6)向博主发起提问,答疑解惑。

8cf1db2f2f113a3b63d329aae0a835f5.png

猜你喜欢

转载自blog.csdn.net/FrankieHello/article/details/132929137
今日推荐