4、八种机器学习方法用于DWT/PSD提取的EEG情感分类的研究

0.前言:

Hello,本篇文章不再探索EEG信号输入的形状问题,来讲一个比较好玩的机器学习。下面是我个人的一个研究。机器学习不必考虑数据输入的形状,最重要的一点在于特征的手动选择。在机器学习中,我们选择不同的特征作为分类器的输入,相当于深度学习中神经网络模型的调参。

最近继续看了一些论文,BCI信号作为输入,无论是原始数据作为输入还是提取了特征作为输入,都必须设计输入形状。有的论文说到,EEG信号输入必须设计特征作为输入,但其他实验证明,EEG原始信号作为输入仍然可以取得相当优秀的结果。

我发现,原始信号作为输入能够取得如此好的成绩,是因为人家设计的模型中各个超参数(kernel_size,filter_number,stride,padding,Activation Function... and etc)和原始信号的输入形状是息息相关,是非常有考究的。确定一个适合于特定输入形状的神经网络模型工作量是巨大的,在深度学习领域,这也是人们为何喜欢直接用前人建立好的模型,比如VGG/RestNet/EEGNet等;这也是在机器学习中人们设计完特征,直接在sklearn中直接调用各个模型的原因。当然了,机器学习中你可以自己写代码建立一个模型,比如写一个随机森林,我们可以对他的叶子节点数目进行修改,记得我之前教过一位研二学生,他也是搞脑电的,但硕导要求他用机器学习去做,他就是把特征设计好了,然后输入到sklearn调用的模型就完事了,我当时挺惊讶:硕士毕业这样就行了吗。他们实验团队都是这样搞得,甚至博士也是如此,只是做着重复的无意义的工作。

1、DEAP Datasets

先说一下本次研究用的数据:DEAP情感数据集。下面是数据集官网,需要你去邮件联系这家机构,说明来意,得到许可,才能下载。这里实验范式不多说,我给出数据地址:DEAP: A Dataset for Emotion Analysis using Physiological and Audiovisual Signals (qmul.ac.uk)

2、processing

预处理中,三次滤波提取了theta、alpha、beta的三个波形并做了加权平均。

3、Feature Extraction Methods

我使用了两种特征提取的方法:

小波变换(DWT):提取EEG时间域特征,原理不多说,代码如下

%for graphs and all
numfiles = 32;
mydata1 = cell(1, numfiles);
spec=[];
hdr=lowpass_filter();
hda=alpha_band_filter();
hdb=beta_band_filter();
filter_data1 = cell(1,numfiles);
fs=128;
opfilename='testFile_01';
myfilename = sprintf('s0%d.mat', 1);
load(myfilename);
datastart=128*3;
datalength=8064-128*3;
for video=1:1
    data1=zeros(1,datalength);
    for ii =1:datalength
        data1(1,ii)=data(video,4,datastart+ii);
    end
    
    %     Plot for PSD of raw data in microVolts
    psd_r=spectrum(data1,256);
    hpsd_r = dspdata.psd(psd_r,'Fs',fs); % Create a PSD data object.
    % plot(hpsd_r);
    xlim([0 300]);
    
    filter_data= filtfilt(hdr.Numerator,1,data1);
    data2=filter_data;
    avpow = norm(data2,2)^2/numel(data2);
    %     Plot for psd of filtered data
    psd_f=spectrum(data2,256);
    hpsd_f = dspdata.psd(psd_f,'Fs',fs); % Create a PSD data object.
    subplot(2,1,2);
    plot(hpsd_f);
    xlim([0 100]);
    hold all;
    
    avpow = norm(data1,2)^2/numel(data1);
    
    % %   delta band
    ts=(length(data1)/128);
    Wp = [1 4]/(fs/2); Ws = [0.5 4.5]/(fs/2);
    Rp = 3; Rs = 40;
    [n,Wn] = buttord(Wp,Ws,Rp,Rs);
    [z, p, k] = butter(n,Wn,'bandpass');
    [sos,g] = zp2sos(z,p,k);
    filt = dfilt.df2sos(sos,g);
    fd1 = filter(filt,data1);
    avpow1 = norm(fd1,2)^2/numel(fd1);
    
    % %  theta band
    Wp = [4 8]/(fs/2); Ws = [3.5 8.5]/(fs/2);
    [n,Wn] = buttord(Wp,Ws,Rp,Rs);
    [z, p, k] = butter(n,Wn,'bandpass');
    [sos,g] = zp2sos(z,p,k);
    filt = dfilt.df2sos(sos,g);
    fd2 = filter(filt,data1);
    avpow2 = norm(fd2,2)^2/numel(fd2);
    
    % %   alpha band
    Wp = [8 13]/(fs/2); Ws = [7.5 13.5]/(fs/2);
    [n,Wn] = buttord(Wp,Ws,Rp,Rs);
    [z, p, k] = butter(n,Wn,'bandpass');
    [sos,g] = zp2sos(z,p,k);
    filt = dfilt.df2sos(sos,g);
    fd3 = filter(filt,data1);
    fd3= filtfilt(hda.Numerator,1,data1);
    avpow3 = norm(fd3,2)^2/numel(fd3);
    
    % % beta band
    Wp = [13 30]/(fs/2); Ws = [12.5 30.5]/(fs/2);
    [n,Wn] = buttord(Wp,Ws,Rp,Rs);
    [z, p, k] = butter(n,Wn,'bandpass');
    [sos,g] = zp2sos(z,p,k);
    filt = dfilt.df2sos(sos,g);
    fd4 = filter(filt,data1);
    fd4= filtfilt(hdb.Numerator,1,data1);
    avpow4 = norm(fd4,2)^2/numel(fd4);
    
    sumpow=avpow1+avpow2+avpow3+avpow4;
    d_bpr=log10(avpow1/sumpow);
    t_bpr=log10(avpow2/sumpow);
    a_bpr=log10(avpow3/sumpow);
    b_bpr=log10(avpow4/sumpow);
    
    valence=labels(video,1);
    arousal=labels(video,2);
    dominance=labels(video,3);
    liking=labels(video,4);
    if(valence<5 && arousal<5)
        op=1;
    end
    if(valence<5 && arousal>5)
        op=2;
    end
    if(valence>5 && arousal<5)
        op=3;
    end
    if(valence>5 && arousal>5)
        op=4;
    end
    % fprintf(fid,'%d,%d,%.3f,%.3f,%.3f,%.3f,%d',1,video,d_bpr,t_bpr,a_bpr,b_bpr,op);
    fprintf(fid,'\n');
    
end %the tetcase loop
fclose(myfile);

功率谱密度(PSD):提取EEG的时频特征,代码如下:

numfiles = 32;
mydata1 = cell(1, numfiles);
spec=[];
hdr=lowpass_filter();
hda=alpha_band_filter();
hdb=beta_band_filter();
filter_data1 = cell(1,numfiles);
fs=128;
for channel=1:32
    fprintf('\ncreating testfile number %d:\n',channel);
    if(channel<10)
       opfilename ='testFile_0';
    else
        opfilename='testFile_';
    end
    filename=[opfilename int2str(channel) '.txt'];filename;
    fid = fopen( filename, 'wt' );
    fprintf(fid,'user,video,delta,theta,alpha,beta,op\n');
    for i = 1:numfiles
       fprintf('\nworking on file number %d:', i);
           if(i<10)
            myfilename = sprintf('s0%d.mat', i);
           else
             myfilename = sprintf('s%d.mat', i);
            end
          load(myfilename);
          for video=1:40
              data1=zeros(1,8064);
              for ii =1:8064
                  data1(1,ii)=data(video,channel,ii);
              end
              avpow = norm(data1,2)^2/numel(data1);

              % %   delta band      
              ts=(length(data1)/128);
              Wp = [1 4]/(fs/2); Ws = [0.5 4.5]/(fs/2);
              Rp = 3; Rs = 40;
              [n,Wn] = buttord(Wp,Ws,Rp,Rs);
              [z, p, k] = butter(n,Wn,'bandpass');
              [sos,g] = zp2sos(z,p,k);
              filt = dfilt.df2sos(sos,g);
              fd1 = filter(filt,data1);
              avpow1 = norm(fd1,2)^2/numel(fd1);

              % %  theta band      
              Wp = [4 8]/(fs/2); Ws = [3.5 8.5]/(fs/2);
              [n,Wn] = buttord(Wp,Ws,Rp,Rs);
              [z, p, k] = butter(n,Wn,'bandpass');
              [sos,g] = zp2sos(z,p,k);
              filt = dfilt.df2sos(sos,g);
              fd2 = filter(filt,data1);
              avpow2 = norm(fd2,2)^2/numel(fd2);

              % %   alpha band 
              Wp = [8 13]/(fs/2); Ws = [7.5 13.5]/(fs/2);
              [n,Wn] = buttord(Wp,Ws,Rp,Rs);
              [z, p, k] = butter(n,Wn,'bandpass');
              [sos,g] = zp2sos(z,p,k);
              filt = dfilt.df2sos(sos,g);
              fd3 = filter(filt,data1);
              fd3= filtfilt(hda.Numerator,1,data1);
              avpow3 = norm(fd3,2)^2/numel(fd3);

              % % beta band
              Wp = [13 30]/(fs/2); Ws = [12.5 30.5]/(fs/2);
              [n,Wn] = buttord(Wp,Ws,Rp,Rs);
              [z, p, k] = butter(n,Wn,'bandpass');
              [sos,g] = zp2sos(z,p,k);
              filt = dfilt.df2sos(sos,g);
              fd4 = filter(filt,data1);
              fd4= filtfilt(hdb.Numerator,1,data1);
              avpow4 = norm(fd4,2)^2/numel(fd4);

            % %gamma band
            %    Wp = [30 50]/(fs/2); Ws = [29.5 50.5]/(fs/2);
            %   [n,Wn] = buttord(Wp,Ws,Rp,Rs);
            %    [z, p, k] = butter(n,Wn,'bandpass');
            %    [sos,g] = zp2sos(z,p,k);
            %   filt = dfilt.df2sos(sos,g);
            %    fd5 = filter(filt,data1);
            %    avpow5 = norm(fd5,2)^2/numel(fd5);

            sumpow=avpow1+avpow2+avpow3+avpow4;
            d_bpr=log10(avpow1/sumpow);
            t_bpr=log10(avpow2/sumpow);
            a_bpr=log10(avpow3/sumpow);
            b_bpr=log10(avpow4/sumpow);
            
            valence=labels(video,1);
            arousal=labels(video,2);
            dominance=labels(video,3);
            liking=labels(video,4);
            if(valence<5 && arousal<5)
                op=1;
            end
            if(valence<5 && arousal>5)
                    op=2;
            end
            if(valence>5 && arousal<5)
                    op=3;
            end
            if(valence>5 && arousal>5)
                    op=4;
            end
            fprintf('|');
            fprintf(fid,'%d,%d,%.3f,%.3f,%.3f,%.3f,%d',i,video,d_bpr,t_bpr,a_bpr,b_bpr,op);
            fprintf(fid,'\n');
            
          end %the tetcase loop
    end %the file loop
    fclose(fid);
end

4、Machine Learning Methods

直接在sklearn库中调用八种机器学习模型进行分类,8个模型:KMeans、SVC(linear)、SVC(rbf)、GaussianNB、RandomForestClassifier、DecisionTresClassifier、LogisticRegression、KNN,并做10折交叉验证。

DWT分类结果:

随机选取第二被试的数据

import numpy as np
from sklearn import neighbors
import pandas as pd
from sklearn import model_selection
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.cluster import KMeans
from sklearn.naive_bayes import GaussianNB
from sklearn import linear_model
from sklearn.linear_model import LogisticRegression
from sklearn import svm
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold 
from sklearn.model_selection import cross_val_score 
import warnings
warnings.filterwarnings("ignore")

df=pd.read_csv('testFile_02.txt')
df.drop(['user','video'],axis=1,inplace=True)
X=np.array(df.drop(['combined'],axis=1))
y=np.array(df['combined'])



kfold = KFold(n_splits=10) 
xyz=[]
accuracy=[]
std=[]
classifiers=['Linear Svm','Radial Svm','Logistic Regression','KNN','Decision Tree','Naive Bayes','Random Forest','KMeans']
models=[svm.SVC(kernel='linear'),svm.SVC(kernel='rbf'),
        LogisticRegression(),neighbors.KNeighborsClassifier(n_neighbors=9),
        tree.DecisionTreeClassifier(),GaussianNB(),RandomForestClassifier(n_estimators=100),KMeans()
        ]

for i in models:
    model = i
    cv_result = cross_val_score(model,X,y, cv = kfold,scoring = "accuracy")
    xyz.append(cv_result.mean())
    std.append(cv_result.std())
    accuracy.append(cv_result)
    print(accuracy)

[array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ])]
[array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.934375  , 0.97421875, 0.94140625, 0.9671875 , 0.97421875,
       0.94921875, 0.9390625 , 0.9234375 , 0.96484375, 0.93203125])]
[array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.934375  , 0.97421875, 0.94140625, 0.9671875 , 0.97421875,
       0.94921875, 0.9390625 , 0.9234375 , 0.96484375, 0.93203125]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ])]
[array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.934375  , 0.97421875, 0.94140625, 0.9671875 , 0.97421875,
       0.94921875, 0.9390625 , 0.9234375 , 0.96484375, 0.93203125]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.87421875, 0.91171875, 0.87890625, 0.92578125, 0.9265625 ,
       0.88671875, 0.9015625 , 0.8546875 , 0.91328125, 0.88125   ])]
[array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.934375  , 0.97421875, 0.94140625, 0.9671875 , 0.97421875,
       0.94921875, 0.9390625 , 0.9234375 , 0.96484375, 0.93203125]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.87421875, 0.91171875, 0.87890625, 0.92578125, 0.9265625 ,
       0.88671875, 0.9015625 , 0.8546875 , 0.91328125, 0.88125   ]), array([0.94921875, 0.971875  , 0.9484375 , 0.96015625, 0.96171875,
       0.940625  , 0.9421875 , 0.9359375 , 0.95      , 0.9296875 ])]
[array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.934375  , 0.97421875, 0.94140625, 0.9671875 , 0.97421875,
       0.94921875, 0.9390625 , 0.9234375 , 0.96484375, 0.93203125]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.87421875, 0.91171875, 0.87890625, 0.92578125, 0.9265625 ,
       0.88671875, 0.9015625 , 0.8546875 , 0.91328125, 0.88125   ]), array([0.94921875, 0.971875  , 0.9484375 , 0.96015625, 0.96171875,
       0.940625  , 0.9421875 , 0.9359375 , 0.95      , 0.9296875 ]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ])]
[array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.934375  , 0.97421875, 0.94140625, 0.9671875 , 0.97421875,
       0.94921875, 0.9390625 , 0.9234375 , 0.96484375, 0.93203125]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.87421875, 0.91171875, 0.87890625, 0.92578125, 0.9265625 ,
       0.88671875, 0.9015625 , 0.8546875 , 0.91328125, 0.88125   ]), array([0.94921875, 0.971875  , 0.9484375 , 0.96015625, 0.96171875,
       0.940625  , 0.9421875 , 0.9359375 , 0.95      , 0.9296875 ]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.9765625 , 0.99921875, 0.97578125, 0.9921875 , 0.984375  ,
       0.97578125, 0.9609375 , 0.953125  , 0.97578125, 0.9515625 ])]
[array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.934375  , 0.97421875, 0.94140625, 0.9671875 , 0.97421875,
       0.94921875, 0.9390625 , 0.9234375 , 0.96484375, 0.93203125]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.87421875, 0.91171875, 0.87890625, 0.92578125, 0.9265625 ,
       0.88671875, 0.9015625 , 0.8546875 , 0.91328125, 0.88125   ]), array([0.94921875, 0.971875  , 0.9484375 , 0.96015625, 0.96171875,
       0.940625  , 0.9421875 , 0.9359375 , 0.95      , 0.9296875 ]), array([0.9765625, 1.       , 0.9765625, 0.9921875, 0.984375 , 0.9765625,
       0.9609375, 0.953125 , 0.9765625, 0.953125 ]), array([0.9765625 , 0.99921875, 0.97578125, 0.9921875 , 0.984375  ,
       0.97578125, 0.9609375 , 0.953125  , 0.97578125, 0.9515625 ]), array([0.07734375, 0.128125  , 0.125     , 0.11484375, 0.10703125,
       0.07890625, 0.1203125 , 0.14609375, 0.1109375 , 0.0609375 ])]

PSD分类结果:

随机选取第三被试数据

import numpy as np
from sklearn import neighbors
import pandas as pd
from sklearn import model_selection
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.cluster import KMeans
from sklearn.naive_bayes import GaussianNB
from sklearn import linear_model
from sklearn.linear_model import LogisticRegression
from sklearn import svm
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold 
from sklearn.model_selection import cross_val_score 
import warnings
warnings.filterwarnings("ignore")

df=pd.read_csv('testFile_03.txt')
df.drop(['user','video'],axis=1,inplace=True)
X=np.array(df.drop(['op'],axis=1))
y=np.array(df['op'])



kfold = KFold(n_splits=10) 
xyz=[]
accuracy=[]
std=[]
classifiers=['Linear Svm','Radial Svm','Logistic Regression','KNN','Decision Tree','Naive Bayes','Random Forest','KMeans']
models=[svm.SVC(kernel='linear'),svm.SVC(kernel='rbf'),
        LogisticRegression(),neighbors.KNeighborsClassifier(n_neighbors=9),
        tree.DecisionTreeClassifier(),GaussianNB(),RandomForestClassifier(n_estimators=100),KMeans()
        ]

for i in models:
    model = i
    cv_result = cross_val_score(model,X,y, cv = kfold,scoring = "accuracy")
    xyz.append(cv_result.mean())
    std.append(cv_result.std())
    accuracy.append(cv_result)
    print(accuracy)

[array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875])]
[array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875])]
[array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.25     , 0.34375  , 0.3671875, 0.3515625, 0.25     , 0.40625  ,
       0.3984375, 0.3359375, 0.4375   , 0.3125   ])]
[array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.25     , 0.34375  , 0.3671875, 0.3515625, 0.25     , 0.40625  ,
       0.3984375, 0.3359375, 0.4375   , 0.3125   ]), array([0.25     , 0.2578125, 0.234375 , 0.3515625, 0.2578125, 0.2421875,
       0.3203125, 0.2734375, 0.3203125, 0.1875   ])]
[array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.25     , 0.34375  , 0.3671875, 0.3515625, 0.25     , 0.40625  ,
       0.3984375, 0.3359375, 0.4375   , 0.3125   ]), array([0.25     , 0.2578125, 0.234375 , 0.3515625, 0.2578125, 0.2421875,
       0.3203125, 0.2734375, 0.3203125, 0.1875   ]), array([0.203125 , 0.2734375, 0.359375 , 0.359375 , 0.25     , 0.2265625,
       0.28125  , 0.2890625, 0.3125   , 0.28125  ])]
[array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.25     , 0.34375  , 0.3671875, 0.3515625, 0.25     , 0.40625  ,
       0.3984375, 0.3359375, 0.4375   , 0.3125   ]), array([0.25     , 0.2578125, 0.234375 , 0.3515625, 0.2578125, 0.2421875,
       0.3203125, 0.2734375, 0.3203125, 0.1875   ]), array([0.203125 , 0.2734375, 0.359375 , 0.359375 , 0.25     , 0.2265625,
       0.28125  , 0.2890625, 0.3125   , 0.28125  ]), array([0.2734375, 0.3203125, 0.3359375, 0.3203125, 0.25     , 0.2890625,
       0.328125 , 0.3125   , 0.25     , 0.2890625])]
[array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.25     , 0.34375  , 0.3671875, 0.3515625, 0.25     , 0.40625  ,
       0.3984375, 0.3359375, 0.4375   , 0.3125   ]), array([0.25     , 0.2578125, 0.234375 , 0.3515625, 0.2578125, 0.2421875,
       0.3203125, 0.2734375, 0.3203125, 0.1875   ]), array([0.203125 , 0.2734375, 0.359375 , 0.359375 , 0.25     , 0.2265625,
       0.28125  , 0.2890625, 0.3125   , 0.28125  ]), array([0.2734375, 0.3203125, 0.3359375, 0.3203125, 0.25     , 0.2890625,
       0.328125 , 0.3125   , 0.25     , 0.2890625]), array([0.1875   , 0.28125  , 0.3359375, 0.3203125, 0.28125  , 0.2890625,
       0.234375 , 0.3046875, 0.296875 , 0.265625 ])]
[array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.2578125, 0.34375  , 0.375    , 0.3515625, 0.234375 , 0.4296875,
       0.390625 , 0.375    , 0.4375   , 0.3046875]), array([0.25     , 0.34375  , 0.3671875, 0.3515625, 0.25     , 0.40625  ,
       0.3984375, 0.3359375, 0.4375   , 0.3125   ]), array([0.25     , 0.2578125, 0.234375 , 0.3515625, 0.2578125, 0.2421875,
       0.3203125, 0.2734375, 0.3203125, 0.1875   ]), array([0.203125 , 0.2734375, 0.359375 , 0.359375 , 0.25     , 0.2265625,
       0.28125  , 0.2890625, 0.3125   , 0.28125  ]), array([0.2734375, 0.3203125, 0.3359375, 0.3203125, 0.25     , 0.2890625,
       0.328125 , 0.3125   , 0.25     , 0.2890625]), array([0.1875   , 0.28125  , 0.3359375, 0.3203125, 0.28125  , 0.2890625,
       0.234375 , 0.3046875, 0.296875 , 0.265625 ]), array([0.09375  , 0.0546875, 0.1015625, 0.140625 , 0.125    , 0.1484375,
       0.0625   , 0.0859375, 0.15625  , 0.0859375])]

5、结语:

通过结果对比得知,不同的特征是机器学习中影响分类准确率的重大因素。我们试想一下,为何DWT特征分类准确率每折在96%左右,而PSD分类只在30%-40%左右呢?是因为人类的情感有着强烈的时域信息吗?试着结合该数据集的原始信号波形图查看波峰,伴随着人类情感的变化,波峰是不是起伏较大呢?波峰起伏较大代表着什么呢?提取的DWT特征对应着峰值的哪些部分呢?

猜你喜欢

转载自blog.csdn.net/mantoudamahou/article/details/132277219
今日推荐