使用GAT训练和测试EEG公开的SEED数据集

下面所有博客是个人对EEG脑电的探索,项目代码是早期版本不完整,需要完整项目代码和资料请私聊。


主要内容:
1、在EEG(脑电)项目中,使用图神经网络对脑电进行处理,具体包括baseline的GCN图架构、复现baseline论文的RGNN架构、注意力机制图架构、Transformer图架构、注重效率的simple图架构等,进行实验和对比。
2、学习图神经网络相关的资料。是学习图神经网络的一个完整项目;



数据集
1、脑电项目探索和实现(EEG) (上):研究数据集选取和介绍SEED
相关论文阅读分析:
1、EEG-SEED数据集作者的—基线论文阅读和分析
2、图神经网络EEG论文阅读和分析:《EEG-Based Emotion Recognition Using Regularized Graph Neural Networks》
3、EEG-GNN论文阅读和分析:《EEG Emotion Recognition Using Dynamical Graph Convolutional Neural Networks》
4、论文阅读和分析:Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification
5、论文阅读和分析:《DeepGCNs: Can GCNs Go as Deep as CNNs?》
6、论文阅读和分析: “How Attentive are Graph Attention Networks?”
7、论文阅读和分析:Simplifying Graph Convolutional Networks

8、论文阅读和分析:LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation
9、图神经网络汇总和总结
相关实验和代码实现:
1、用于图神经网络的脑电数据处理实现_图神经网络 脑电
2、使用GCN训练和测试EEG的公开SEED数据集
3、使用GAT训练和测试EEG公开的SEED数据集
4、使用SGC训练和测试SEED数据集
5、使用Transformer训练和测试EEG的公开SEED数据集_eeg transformer
6、使用RGNN训练和测试EEG公开的SEED数据集
辅助学习资料:
1、官网三个简单Graph示例说明三种层次的应用_graph 简单示例
2、PPI数据集示例项目学习图神经网络
3、geometric库的数据处理详解
4、NetworkX的dicts of dicts以及解决Seven Bridges of Königsberg问题
5、geometric源码阅读和分析:MessagePassin类详解和使用
6、cora数据集示例项目学习图神经网络
7、Graph 聚合
8、QM9数据集示例项目学习图神经网络
9、处理图的开源库

部分代码如下:

# -*- coding: utf-8 -*-
#
# Copyright (C) 2022 Emperor_Yang, Inc. All Rights Reserved 
#
# @CreateTime    : 2023/2/19 22:13
# @Author        : Emperor_Yang 
# @File          : ECG_GAT.py
# @Software      : PyCharm


import torch
import torch.nn.functional as F
from easydict import EasyDict
from torch_geometric.nn import GATConv, global_add_pool
from torch_geometric.data import DataLoader
from data_process.seed_loader_gnn_memory import SeedGnnMemoryDataset

config = EasyDict()
config.learn_rate = 0.01
config.epoch = 20
config.note_feature_dim = 5
config.note_num = 62
config.hidden_channels = 16
config.class_num = 3
config.hidden_layers = 1
config.batch_size = 16
config.max_loss_increase_time = 3
config.head_num = 2


class EEG_GAT(torch.nn.Module):
    """
    GCN handle ECG
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(EEG_GAT, self).__init__()

        self.conv_s = torch.nn.ModuleList()
        # output dim is heads * out_channel
        self.conv_s.append(GATConv(in_channels, hidden_channels, config.head_num, dropout=0.6))
        for i in range(config.hidden_layers - 1):
            self.conv_s.append(GATConv(hidden_channels * config.head_num, hidden_channels, config.head_num, dropout=0.6))

        self.fc1 = torch.nn.Linear(hidden_channels , out_channels)

    def forward(self, x, edge_index, index, edge_weight=None):
        """
        forward
        :param index:
        :param x:note feature
        :param edge_index:edge pair
        :param edge_weight: edge feature
        :return:
        """
        for conv in self.conv_s:
            x = conv(x, edge_index, edge_weight).relu()
        x = global_add_pool(x, index)
        x = self.fc1(x)
        return x


model = EEG_GAT(config.note_feature_dim, config.hidden_channels, config.class_num)
data_set = SeedGnnMemoryDataset(root='../data/SEED/', processed_file='1_20131027.pt')
train_data_set = data_set[: int(0.8 * data_set.len())]
test_data_set = data_set[int(0.8 * data_set.len()):]
train_data_loader = DataLoader(train_data_set, batch_size=config.batch_size, shuffle=True)
test_data_loader = DataLoader(test_data_set, batch_size=config.batch_size, shuffle=False)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learn_rate)
criterion = torch.nn.CrossEntropyLoss()


def train():
    loss_sum = 0
    data_size = 0
    for mini_batch in train_data_loader:
        if mini_batch.num_graphs == config.batch_size:
            data_size += mini_batch.num_graphs
            model.train()
            optimizer.zero_grad()
            out = model(mini_batch.x, mini_batch.edge_index, mini_batch.batch)
            loss = criterion(out, mini_batch.y)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item() / mini_batch.num_graphs
    return loss_sum / data_size


def test():
    count = 0
    data_size = 0
    for mini_batch in test_data_loader:
        if mini_batch.num_graphs == config.batch_size:
            out = model(mini_batch.x, mini_batch.edge_index, mini_batch.batch)
            predict = torch.argmax(out, dim=1)
            count += int(predict.eq(mini_batch.y).sum())
            data_size += mini_batch.num_graphs
    print("Test Accuracy:{}%".format(count / data_size * 100))


if __name__ == '__main__':
    loss_increase_time = 0
    last_lost = 1
    for epoch in range(config.epoch):
        avg_loss = train()
        print("epoch:{}, loss:{}".format(epoch+1, avg_loss))
        if avg_loss > last_lost:
            loss_increase_time += 1
        else:
            last_lost = avg_loss
        # 如果连续增加loss大于config.max_loss_increase_time,则停止训练
        if loss_increase_time > config.max_loss_increase_time:
            break
    test()

猜你喜欢

转载自blog.csdn.net/KPer_Yang/article/details/129074872