8. Pseudo_Siamese_Network: 4-class classification on BCI IV 2a

Preface:

I have graduated, but I am still following the development of Siamese models. The model has remained a hot topic in recent years; relatively few papers were published at first, but since last year the number of articles has risen steadily, and as the SCI papers accumulate, more of the research community has become aware of it. I recently came across a top-journal paper from January 2023, "A Siamese Network-Based Method for Improving the Performance of Sleep Staging with Single-Channel EEG", a joint effort between the School of Automation at Beijing Institute of Technology and Peking Union Medical College. The authors built two models, Siamese CNNs and Siamese AEs, and reached a fairly high classification accuracy of 87.2% with a Kappa of 0.81, which is quite respectable. According to the authors, the models encode the similarity between two EEG signals from the same sleep stage and can also encode the differences between EEG epochs from different sleep stages. The two models are shown below:

Figure 1  Siamese CNNs

Figure 2  Siamese AEs

I will not go through the detailed structures of the two models here; interested readers can look up the paper. What matters here are the lessons these two models offer for building Siamese networks:

For EEG data (note: this applies to EEG signals, not images):

1. Smaller convolution kernels are better suited to extracting temporal features

2. Larger convolution kernels are better suited to extracting frequency features

3. Different kernel sizes capture information from different frequency bands (see the sketch after this list)

4. For pooling, even kernel sizes tend to work better
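
To make points 1-3 concrete, here is a minimal sketch (my own illustration, not taken from the paper; the kernel sizes and the 250 Hz sampling rate are assumptions) of how a small and a large temporal kernel see the same single-channel EEG window:

import torch
from torch import nn

# One 2-second window of single-channel EEG at an assumed 250 Hz sampling rate.
x = torch.randn(1, 1, 500)                    # (batch, channels, time)

# A small kernel (5 samples = 20 ms) reacts to fast, local temporal changes.
small_kernel = nn.Conv1d(1, 8, kernel_size=5)
# A large kernel (125 samples = 0.5 s) spans several cycles of the slower rhythms,
# so it behaves more like a learned band-pass filter along the frequency direction.
large_kernel = nn.Conv1d(1, 8, kernel_size=125)

print(small_kernel(x).shape)   # torch.Size([1, 8, 496])
print(large_kernel(x).shape)   # torch.Size([1, 8, 376])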

Siamese CNNs: the two Siamese input branches each use a different CNN, which makes the design a pseudo-Siamese network. Kernels of different sizes are interleaved to extract as much temporal and spatial EEG feature information as possible, and the stacked convolutions follow the VGG design.
Siamese AEs: this variant adds an encoder and a decoder. The encoder is 4 CNN layers plus 1 dropout layer, and the decoder uses 1D transposed convolution layers (ConvTranspose1d). L1 and L2 penalty norms are also added, and overall this model beats the Siamese CNNs by about 2%.
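
As a rough illustration of that encoder/decoder idea (a minimal sketch of my own, not the paper's architecture; the layer widths, kernel sizes and penalty weights are all assumptions), one branch of such an autoencoder could look like this:

import torch
from torch import nn

class TinyEEGAutoencoder(nn.Module):
    """Toy 1D conv autoencoder for a single EEG channel (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 8, kernel_size=7, stride=2, padding=3), nn.ReLU(),
            nn.Conv1d(8, 16, kernel_size=7, stride=2, padding=3), nn.ReLU(),
            nn.Dropout(p=0.5))
        # ConvTranspose1d mirrors the strided convolutions to recover the input length.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(16, 8, kernel_size=7, stride=2, padding=3, output_padding=1), nn.ReLU(),
            nn.ConvTranspose1d(8, 1, kernel_size=7, stride=2, padding=3, output_padding=1))

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z), z

model = TinyEEGAutoencoder()
x = torch.randn(4, 1, 1000)
recon, z = model(x)
# Reconstruction loss plus small L1/L2 penalties on the latent code, in the spirit of
# the penalty norms mentioned above (the weights here are made up for this sketch).
loss = nn.functional.mse_loss(recon, x) + 1e-4 * z.abs().mean() + 1e-4 * z.pow(2).mean()
print(recon.shape, loss.item())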

Pseudo_Siamese_Network: 4-class classification on BCI IV 2a

Inspired by that paper, I built my own pseudo-Siamese model to work on the BCI IV 2a data for 4-class classification, i.e. predicting motor imagery from brain signals. I have also been studying feature fusion lately, and the convergence of model-level feature fusion is a common headache: in one fusion model I built, one branch was a CNN and the other an LSTM, and the outputs the two branches delivered to the decision layer were so different that the fused model simply would not converge. A Siamese network looks like a good way out of this: once trained, it can handle feature fusion at both the model level and the decision level, and I will experiment with that later. For now I am once again taking the long-suffering BCI dataset as the test subject. The full code is below:
 

import os
import logging
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

import sklearn
from sklearn import metrics, preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

from model import *  # external model definitions used elsewhere in the project

transf = transforms.ToTensor()


# helper classes
class Config:
    batch_size = 32
    lr = 0.01 
    epochs = 100


class ContrastiveLoss(torch.nn.Module):
    """Margin-based contrastive loss on the Euclidean distance between two embeddings."""
    def __init__(self, margin=0.5):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # For a binary similarity label: 0 = similar pair (penalise any distance),
        # 1 = dissimilar pair (penalise only distances still inside the margin).
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin -
                                       euclidean_distance, min=0.0), 2))

        return loss_contrastive
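
# A quick numeric sanity check of the formula above (illustrative values, not taken from
# the training data): with margin = 0.5 and a pairwise distance of roughly 0.1,
#   label = 0  ->  loss ≈ 0.1 ** 2          = 0.01  (similar pair: penalise the distance)
#   label = 1  ->  loss ≈ (0.5 - 0.1) ** 2  = 0.16  (dissimilar pair still inside the margin)
# e.g. ContrastiveLoss()(torch.tensor([[0.1, 0.2]]), torch.tensor([[0.1, 0.3]]), torch.tensor([1.]))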


class pseudo_siamese_network(nn.Module):
    def __init__(self):
        super(pseudo_siamese_network,self).__init__()
        # Branch 1: small kernels, aimed at fine-grained temporal/spatial detail.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1,8,(2,10)),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.MaxPool2d(8,8),
            nn.Dropout(p=0.5),
            
            nn.Conv2d(8,16,(2,1)),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16,32,(1,20)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32,32,(1,1)),
            nn.ReLU())
        
        # Branch 2: larger kernels, aimed at coarser, more frequency-like structure.
        self.conv2 = nn.Sequential(
            nn.Conv2d(1,8,(10,100)),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.MaxPool2d(4,4),
            nn.Dropout(p=0.5),
            
            nn.Conv2d(8,16,(2,1)),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16,32,(1,50)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32,32,(1,1)),
            nn.ReLU())
        
        self.fc1 = nn.Sequential(
            nn.Linear(3328, 4096),     # 3328 = 32 * 1 * 104, see the shape notes after this class
            nn.ReLU(),
            nn.Linear(4096, 4),
            nn.Softmax(dim=1))

        self.fc2 = nn.Sequential(
            nn.Linear(11264, 4096),    # 11264 = 32 * 2 * 176, see the shape notes after this class
            nn.ReLU(),
            nn.Linear(4096, 4),
            nn.Softmax(dim=1))
        
    def forward_one(self,x1):
        x1 = self.conv1(x1)
        x1 = x1.contiguous().view(x1.size()[0],-1)
        x1 = self.fc1(x1)
        return x1
    
    def forward_two(self,x2):
        x2 = self.conv2(x2)
        x2 = x2.contiguous().view(x2.size()[0],-1)
        x2 = self.fc2(x2)
        return x2
    
    def forward(self,x1,x2):
        output1 = self.forward_one(x1)
        output2 = self.forward_two(x2)
        return output1,output2
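
# For reference, how the flattened feature sizes above fall out for a (1, 22, 1000) input
# (22 channels x 1000 time samples), using the layer parameters defined in this class:
#   branch 1: Conv2d(2,10) -> (8, 21, 991); MaxPool2d(8,8) -> (8, 2, 123);
#             Conv2d(2,1) -> (16, 1, 123); Conv2d(1,20) -> (32, 1, 104); Conv2d(1,1) -> (32, 1, 104)
#             flatten: 32 * 1 * 104 = 3328   -> fc1 input size
#   branch 2: Conv2d(10,100) -> (8, 13, 901); MaxPool2d(4,4) -> (8, 3, 225);
#             Conv2d(2,1) -> (16, 2, 225); Conv2d(1,50) -> (32, 2, 176); Conv2d(1,1) -> (32, 2, 176)
#             flatten: 32 * 2 * 176 = 11264  -> fc2 input size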
    
class feature_dataset(Dataset):
    def __init__(self,file_path,target_path,transform =None):
        self.file_path = file_path
        self.target_path = target_path

        self.data = self.parse_data_file(file_path)
        self.target = self.parse_target_file(target_path)
        self.transform = transform

    def parse_data_file(self, file_path):
        # Each row of the CSV is one trial, flattened from 22 channels x 1000 samples.
        data = pd.read_csv(file_path, header=None)
        data = np.array(data, dtype=np.float32)
        num_sub = len(data)
        # Standardise per feature column, then reshape back to (trials, 22, 1000).
        scaler = StandardScaler().fit(data)
        data = scaler.transform(data)
        data = data.reshape(num_sub, 22, 1000)
        data = torch.tensor(data)
        data = data.unsqueeze(1)   # add the singleton "image" channel: (trials, 1, 22, 1000)
        return np.array(data, dtype=np.float32)

    def parse_target_file(self, target_path):
        # The label CSV holds one class id per trial; one-hot encode and take the argmax
        # so that arbitrary label values are mapped to contiguous indices 0..3.
        target = pd.read_csv(target_path, header=None)
        target = np.array(target, dtype=np.float32)
        encoder = preprocessing.OneHotEncoder(handle_unknown='ignore')
        x_hot = target.reshape(-1, 1)
        encoder.fit(x_hot)
        x_oh = encoder.transform(x_hot).toarray()
        d = transf(x_oh)
        x_hot_label = torch.argmax(d, dim=2).long()
        label = x_hot_label.transpose(1, 0)
        target = torch.squeeze(label)
        return target

    def __len__(self):
        return len(self.data) 

    def __getitem__(self, index):
        # x1 is the indexed trial; x2 is a randomly drawn trial, so pairs mix
        # same-class and different-class examples. Only x1's label is returned.
        x1 = self.data[index, :]
        index2 = np.random.choice(len(self.data))
        x2 = self.data[index2, :]
        target = self.target[index]

        return x1, x2, target

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
config = Config()

train_transforms = transforms.Compose([transforms.ToTensor()])
train_dataset = feature_dataset(file_path = r'C:\MI-EEG-A01T.csv',
                                target_path=r'C:\EtiquetasA01T.csv')
train_dataloader = DataLoader(train_dataset,shuffle=True,batch_size=config.batch_size)

val_transforms = transforms.Compose([transforms.ToTensor()])
val_dataset = feature_dataset(file_path =r'C:\MI-EEG-A01E.csv',
                              target_path=r'C:\EtiquetasA01E.csv')
val_dataloader = DataLoader(val_dataset,batch_size=config.batch_size)  

model = pseudo_siamese_network().to(device)
optimizer = optim.SGD(model.parameters(),lr=config.lr)
#criterion = torch.nn.CrossEntropyLoss()
criterion = ContrastiveLoss()
loss_fn = criterion.to(device) 


    
def show_plot(accuracy_history,loss_history,test_accuracy):
    plt.figure(figsize=(20,10))
    #fig2
    plt.subplot(121)
    plt.plot(loss_history,marker=".",color="c")
    plt.title('train loss')
    #fig3
    plt.subplot(122)
    plt.plot(accuracy_history,marker="o",label="train_acc")  # plt.plot(x, y): set the x/y data, colour, marker style, size, etc.
    plt.plot(test_accuracy, marker='o', label="test_acc")
    plt.title("ACC")
    plt.legend(loc="best")
    plt.savefig('acc_loss.png')
    plt.show()
    
def plot_recall(epoch_list,recall1,recall2,recall3,recall4):
    plt.figure(figsize=(15,8)) 
    plt.plot(epoch_list,recall1, color='purple', label='Back1_Recall',marker=".")
    plt.plot(epoch_list,recall2,color='c',label="Back2_Recall",marker=".")
    plt.plot(epoch_list,recall3,color='g',label="Back3_Recall",marker=".")
    plt.plot(epoch_list,recall4,color='m',label="Back4_Recall",marker=".")
    plt.title('Recall during test')
    plt.xlabel('Epoch')
    plt.ylabel('Recall_Values')
    plt.legend()
    plt.savefig("recall.jpg")
    plt.show()

def plot_precision(epoch_list,precision1,precision2,precision3,precision4):
    plt.figure(figsize=(15,8))
    plt.plot(epoch_list,precision1, color='black', label='Back1_Precision',marker="o")
    plt.plot(epoch_list,precision2, color='b', label='Back2_Precision',marker="o")
    plt.plot(epoch_list,precision3, color='m', label='Back3_Precision',marker="o")
    plt.plot(epoch_list,precision4, color='c', label='Back4_Precision',marker="o")
    plt.xlabel('Epoch')
    plt.ylabel('Precision_Values')
    plt.title('Precision during test')
    plt.legend()
    plt.savefig("precision.jpg")
    
    plt.show()

def plot_f1(epoch_list,f1_1,f1_2,f1_3,f1_4):
    plt.figure(figsize=(15,8))
    plt.plot(epoch_list,f1_1, color='yellow', label='Back1_F1',marker="^")
    plt.plot(epoch_list,f1_2, color='g', label='Back2_F1',marker="^")
    plt.plot(epoch_list,f1_3, color='b', label='Back3_F1',marker="^")
    plt.plot(epoch_list,f1_4, color='m', label='Back4_F1',marker="^")
    plt.xlabel('Epoch')
    plt.ylabel('F1_Values')
    plt.title('F1 during test')
    plt.legend()
    plt.savefig("f1.jpg")
    plt.show()
    
def get_logger(filename, verbosity=1, name=None):
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s]%(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger
    
def DrawConfusionMatrix(save_model_name,val_dataloader):
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = pseudo_siamese_network().to(device)
    model.load_state_dict(torch.load(os.path.join(save_model_name,"best.pth")))
    model.eval()
    predict = []
    gt = []
    with torch.no_grad():
        for data_label in val_dataloader:
            # the validation loader yields a pair of inputs plus the label of x1
            x1, x2, target = data_label
            x1, x2, target = x1.to(device), x2.to(device), target.to(device)
            output1, output2 = model(x1, x2)
            # the label belongs to x1, so the confusion matrix is built from branch 1
            _, predicted = torch.max(output1, 1)

            tmp_predict = predicted.cpu().detach().numpy()
            tmp_label = target.cpu().detach().numpy()

            if len(predict) == 0:
                predict = np.copy(tmp_predict)
                gt = np.copy(tmp_label)
            else:
                predict = np.hstack((predict,tmp_predict))
                gt = np.hstack((gt,tmp_label))

    cm = confusion_matrix(y_true=gt, y_pred=predict)   # confusion matrix
    disp = metrics.ConfusionMatrixDisplay(confusion_matrix=cm,display_labels = ['left','right','feet',"tongue"])
    disp.plot()
    save_hunxiao_path = os.path.join(save_model_name,'confusion_matrix.png')
    plt.savefig(save_hunxiao_path,dpi = 1000)

recall1= []
recall2= []
recall3 = []
recall4 = []
precision1 = []
precision2 = []
precision3 = []
precision4 = []
f1_1 = []
f1_2 = []
f1_3 = []
f1_4 = []
epoch_list = []    
accuracy_history = []
loss_history = []
test_accuracy = []
best_acc = 0



logger=get_logger(os.path.join(r"C:\Users\19067\Desktop\xiandianzi",'all_trail_exp.log'))

for epoch in range(0,Config.epochs):
    model.train()
    counter = []
    iteration_number = 0
    train_correct = 0
    total = 0
    correct = 0
    train_loss = 0
    for i,data in enumerate(train_dataloader,0):  # one pass over the shuffled DataLoader per epoch
        x1,x2,label = data
        #data = np.copy(data)
        
        x1,x2,label= x1.to(device),x2.to(device),label.to(device)
        
        optimizer.zero_grad() 
        output1,output2 = model(x1,x2)  
        loss = loss_fn(output1,output2,label) 
        loss.backward()   
        optimizer.step()  
        # Both branch outputs are scored against x1's label; the dataset does not return
        # x2's own label, so this accuracy is only a rough training indicator.
        output = torch.concat((output1,output2))
        label = torch.concat((label,label))
        predicted=torch.argmax(output, 1)
        train_correct += (predicted == label).sum().item()
        total+=label.size(0) 
        
        train_loss += loss
    train_accuracy = train_correct / total
    train_loss /= len(train_dataloader)
    train_loss = train_loss.item()
    iteration_number += 1
    
    counter.append(iteration_number)
    accuracy_history.append(train_accuracy)
    loss_history.append(train_loss)
    
    # print("Epoch number {}\n Current Train  Accuracy {}\n Current Train loss {}\n".format
    #         (epoch, train_accuracy,train_loss))
    logger.info("Epoch number {}\n Current Train  Accuracy {}\n Current Train loss {}\n".format
            (epoch, train_accuracy,train_loss))
    
    with torch.no_grad():
        model.eval()
        test_correct = 0
        total =  0
        tensor_concat_pre_label = []
        label_item = []
        epoch_recall = 0
        epoch_precision = 0
        epoch_f1 = 0
        n_classes = 4
        target_num = torch.zeros((1, n_classes)) 
        predict_num = torch.zeros((1, n_classes))
        acc_num = torch.zeros((1, n_classes))

        for idx, data in enumerate(val_dataloader,0):
            x1,x2,label = data
            x1,x2,label = x1.to(device),x2.to(device),label.to(device)
            output1,output2 = model(x1,x2)
            output = torch.concat((output1,output2))
            label = torch.concat((label,label))
            predicted=torch.argmax(output, 1)
            test_correct += (predicted == label).sum().item()
            total+=label.size(0)
            # 1. precision/recall/F1 report
            pred = predicted
            y_true = label.cpu()
            y_pred = pred.float().cpu()
            if len(tensor_concat_pre_label)==0:
                tensor_concat_pre_label = y_pred.clone()
                label_item = y_true.clone()
            else:
                tensor_concat_pre_label = torch.concat((tensor_concat_pre_label,y_pred))
                label_item = torch.concat((label_item,y_true)) 
        print(accuracy_score(label_item,tensor_concat_pre_label),
              classification_report(label_item,tensor_concat_pre_label))
        print(metrics.confusion_matrix(label_item,tensor_concat_pre_label))
        
        # logger.info(accuracy_score(label_item,tensor_concat_pre_label))
        # 2. overall accuracy
        current_test_acc = test_correct / total
        test_accuracy.append(current_test_acc)
        print("test acc: ", current_test_acc * 100, "%")
        # 3. per-class precision/recall/F1, computed over the whole validation set
        # (the masks are built from the concatenated predictions, not just the last batch)
        pre_mask = torch.zeros(len(tensor_concat_pre_label), n_classes).scatter_(1, tensor_concat_pre_label.long().view(-1, 1), 1.)
        predict_num += pre_mask.sum(0)  # number of predictions per class

        tar_mask = torch.zeros(len(label_item), n_classes).scatter_(1, label_item.long().view(-1, 1), 1.)
        target_num += tar_mask.sum(0)  # number of samples per class

        acc_mask = pre_mask * tar_mask
        acc_num += acc_mask.sum(0)  # correctly classified samples per class

    recall = acc_num / target_num
    precision = acc_num / predict_num
    F1 = 2 * recall * precision / (recall + precision)
    recall  = recall.numpy()
    precision = precision.numpy()
    F1 = F1.numpy()
    recall_back1,recall_back2,recall_back3,recall_back4 = recall[:,0],recall[:,1],recall[:,2],recall[:,3]
    precision_back1,precision_back2,precision_back3,precision_back4 = precision[:,0],precision[:,1],precision[:,2],precision[:,3]
    F1_back1,F1_back2,F1_back3,F1_back4 = F1[:,0],F1[:,1],F1[:,2],F1[:,3]
    #accuracy = 100. * acc_num.sum(1) / target_num.sum(1)
    epoch_list.append(epoch)
    recall1.append(recall_back1)
    recall2.append(recall_back2)
    recall3.append(recall_back3)
    recall4.append(recall_back4)
    precision1.append(precision_back1)
    precision2.append(precision_back2)
    precision3.append(precision_back3)
    precision4.append(precision_back4)
    f1_1.append(F1_back1)  
    f1_2.append(F1_back2)
    f1_3.append(F1_back3)
    f1_4.append(F1_back4)

    if current_test_acc > best_acc and epoch>Config.epochs/2:
        best_acc = current_test_acc
        torch.save(model.state_dict(),'best.pth')
    
DrawConfusionMatrix(r"C:\Users\19067\Desktop\xiandianzi",val_dataloader)
 
plot_recall(epoch_list,recall1,recall2,recall3,recall4)
plot_precision(epoch_list,precision1,precision2,precision3,precision4)
plot_f1(epoch_list,f1_1,f1_2,f1_3,f1_4)
show_plot(accuracy_history,loss_history,test_accuracy) 


    

from torchsummary import summary
# the network takes two inputs, so pass one input_size per branch
summary(model, [(1, 22, 1000), (1, 22, 1000)], batch_size=32, device="cuda")
print(model)

Results:

Using the SGD optimizer,

lr = 0.01,

contrastive loss with margin 0.5 (the cross-entropy criterion is left commented out in the code),

batch_size = 32,

and CUDA acceleration on a 12th Gen Intel(R) Core(TM) i7-12700 @ 2.10 GHz under a 64-bit operating system, the model eventually converged over 800 epochs, reaching a best test classification ACC of 65.972%.

The per-class PR, RE and F1 scores, the confusion matrix, and the ACC for the 4 classes are shown below:

Closing remarks:

For a model thrown together like this, the results are decent. I will make further changes to the model and the data later; updates to follow...
