Predicting the Dow Jones Index with word2vec on Ten Years of News Text Data

Copyright notice: if you repost this article, please credit Juanlyjack, https://blog.csdn.net/m0_38088359/article/details/82860617

1. Data description
(1) News data: historical news headlines crawled from the Reddit WorldNews Channel (/r/worldnews). Headlines are ranked by Reddit users' votes, and only the top 25 headlines of each date are kept. (Range: 2008-06-08 to 2016-07-01)
(2) Stock data: the Dow Jones Industrial Average (DJIA) is used as a proof of concept. (Range: 2008-08-08 to 2016-07-01)
File format: csv
Combined_News_DJIA.csv
This is a combined dataset with 27 columns. The first column is "Date", the second is "Label", and the remaining columns are the headlines from "Top1" to "Top25".
The label is "1" when the DJIA Adj Close rose or stayed flat, and "0" when it fell.
For evaluation, the data from 2008-08-08 to 2014-12-31 is used as the training set, and the test set is the following roughly two years (2015-01-02 to 2016-07-01), which is about an 80%/20% split.
The final result is evaluated with AUC.
Download link: https://pan.baidu.com/s/12Y2fVIJ7yhnlJGQCkyYysg  password: xwg8
2. Text preprocessing
(1) Inspecting the data

import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score
from datetime import date
# Read the data
data = pd.read_csv('Combined_News_DJIA.csv',header=0,encoding='utf8')
# Look at the first five rows; the format is very straightforward: Label is 1 when the DJIA
# went up or stayed flat that day, and 0 when it went down.
data.head()

(2) Splitting into training and test sets

train = data[data['Date'] < '2015-01-01']
test = data[data['Date'] > '2014-12-31']

Next, treat each headline as a separate sentence and gather all of them into a corpus. We use flatten() to flatten all the text into one numpy array that serves as the corpus. Note that X_train and X_test themselves cannot simply be flattened, because they have to stay aligned with y_train and y_test.

X_train = train[train.columns[2:]]
corpus = X_train.values.flatten().astype(str)

X_train = X_train.values.astype(str)
X_train = np.array([' '.join(x) for x in X_train])
X_test = test[test.columns[2:]]
X_test = X_test.values.astype(str)
X_test = np.array([' '.join(x) for x in X_test])
y_train = train['Label'].values
y_test = test['Label'].values

After splitting, the data looks like this:

corpus[:3]

Output:

array([ 'b"Georgia \'downs two Russian warplanes\' as countries move to brink of war"',
       "b'BREAKING: Musharraf to be impeached.'",
       "b'Russia Today: Columns of troops roll into South Ossetia; footage from fighting (YouTube)'"], 
      dtype='<U312')
X_train[:1]

Output:

array([ 'b"Georgia \'downs two Russian warplanes\' as countries move to brink of war" b\'BREAKING: Musharraf to be impeached.\' b\'Russia Today: Columns of troops roll into South Ossetia; footage from fighting (YouTube)\' b\'Russian tanks are moving towards the capital of South Ossetia, which has reportedly been completely destroyed by Georgian artillery fire\' b"Afghan children raped with \'impunity,\' U.N. official says - this is sick, a three year old was raped and they do nothing" b\'150 Russian tanks have entered South Ossetia whilst Georgia shoots down two Russian jets.\' b"Breaking: Georgia invades South Ossetia, Russia warned it would intervene on SO\'s side" b"The \'enemy combatent\' trials are nothing but a sham: Salim Haman has been sentenced to 5 1/2 years, but will be kept longer anyway just because they feel like it." b\'Georgian troops retreat from S. Osettain capital, presumably leaving several hundred people killed. [VIDEO]\' b\'Did the U.S. Prep Georgia for War with Russia?\' b\'Rice Gives Green Light for Israel to Attack Iran: Says U.S. has no veto over Israeli military ops\' b\'Announcing:Class Action Lawsuit on Behalf of American Public Against the FBI\' b"So---Russia and Georgia are at war and the NYT\'s top story is opening ceremonies of the Olympics?  What a fucking disgrace and yet further proof of the decline of journalism." b"China tells Bush to stay out of other countries\' affairs" b\'Did World War III start today?\' b\'Georgia Invades South Ossetia - if Russia gets involved, will NATO absorb Georgia and unleash a full scale war?\' b\'Al-Qaeda Faces Islamist Backlash\' b\'Condoleezza Rice: "The US would not act to prevent an Israeli strike on Iran." Israeli Defense Minister Ehud Barak: "Israel is prepared for uncompromising victory in the case of military hostilities."\' b\'This is a busy day:  The European Union has approved new sanctions against Iran in protest at its nuclear programme.\' b"Georgia will withdraw 1,000 soldiers from Iraq to help fight off Russian forces in Georgia\'s breakaway region of South Ossetia" b\'Why the Pentagon Thinks Attacking Iran is a Bad Idea - US News &amp; World Report\' b\'Caucasus in crisis: Georgia invades South Ossetia\' b\'Indian shoe manufactory  - And again in a series of "you do not like your work?"\' b\'Visitors Suffering from Mental Illnesses Banned from Olympics\' b"No Help for Mexico\'s Kidnapping Surge"'], 
      dtype='<U4424')

The target values:

y_train[:5]

Output:

array([0, 1, 0, 0, 1])

Next we tokenize the text. In the past one would often use jieba (for Chinese text); here we use word_tokenize from the third-party NLTK library. Note that the items of corpus and X_train mean different things: each item of corpus is a single headline, while each item of X_train is all of one day's headlines joined together (matching one label).

from nltk.tokenize import word_tokenize

corpus = [word_tokenize(x) for x in corpus]
X_train = [word_tokenize(x) for x in X_train]
X_test = [word_tokenize(x) for x in X_test]

Preprocessing steps:
* lowercase everything
* remove stopwords
* remove numbers and symbols
* lemmatize, so that words are reduced to a common form (stripping tense, plural/person variants, etc.)
The stopword list comes from stopwords in nltk.corpus. If you hit an error here, just follow the error message and download the corresponding NLTK resource, as in the sketch below.
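If NLTK raises a LookupError naming a missing package, the downloads below are the ones this pipeline typically needs; a minimal sketch, run once:

import nltk

# One-off downloads; each call is a no-op if the resource is already installed.
nltk.download('punkt')      # tokenizer models used by word_tokenize
nltk.download('stopwords')  # English stopword list
nltk.download('wordnet')    # lemma dictionary used by WordNetLemmatizer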

# Stopwords
from nltk.corpus import stopwords
stop = stopwords.words('english')

# Numbers
import re
def hasNumbers(inputString):
    return bool(re.search(r'\d', inputString))

# Special symbols
def isSymbol(inputString):
    return bool(re.match(r'[^\w]', inputString))

# lemma
from nltk.stem import WordNetLemmatizer
wordnet_lemmatizer = WordNetLemmatizer()

def check(word):
    """
    如果需要这个单词,则True
    如果应该去除,则False
    """
    word= word.lower()
    if word in stop:
        return False
    elif hasNumbers(word) or isSymbol(word):
        return False
    else:
        return True

# Combine the helpers above
def preprocessing(sen):
    res = []
    for word in sen:
        if check(word):
            # This only strips the b'/b" markers left over from Python byte strings saved as str;
            # the raw data wasn't cleaned beforehand, so other datasets won't have this issue
            word = word.lower().replace("b'", '').replace('b"', '').replace('"', '').replace("'", '')
            res.append(wordnet_lemmatizer.lemmatize(word))
    return res

Apply the preprocessing to the corpus and to the training/test sets:

corpus = [preprocessing(x) for x in corpus]
X_train = [preprocessing(x) for x in X_train]
X_test = [preprocessing(x) for x in X_test]

3. Training the NLP model
(1) A quick word2vec recap: the word2vec algorithm covers the skip-gram and CBOW models, trained with either hierarchical softmax or negative sampling. word2vec is available in the gensim library.
Parameters in detail:
gensim.models.word2vec.Word2Vec(sentences=None, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, sample=0.001, seed=1, workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=<built-in function hash>, iter=5, null_word=0, trim_rule=None, sorted_vocab=1, batch_words=10000)

· sentences: can be a list; for large corpora it is recommended to build it with BrownCorpus, Text8Corpus or LineSentence.
· sg: selects the training algorithm. Default 0 means CBOW; sg=1 uses skip-gram.
· size: dimensionality of the feature vectors, default 100. A larger size needs more training data but gives better results; 300-800 is a commonly recommended range.
· window: maximum distance within a sentence between the current word and the predicted word.
· alpha: the learning rate.
· seed: seed for the random number generator, used when initialising the word vectors.
· min_count: truncates the vocabulary; words appearing fewer than min_count times are dropped. Default 5.
· max_vocab_size: RAM limit during vocabulary building. If the number of unique words exceeds this, the least frequent ones are pruned. Roughly 1 GB of RAM per 10 million words. None means no limit.
· sample: threshold for random downsampling of high-frequency words, default 1e-3, useful range (0, 1e-5).
· workers: number of worker threads used for training.
· hs: if 1, hierarchical softmax is used; if 0 (default), negative sampling is used.
· negative: if > 0, negative sampling is used, and this sets how many noise words are drawn.
· cbow_mean: if 0, use the sum of the context word vectors; if 1 (default), use their mean. Only applies when CBOW is used.
· hashfxn: hash function used to initialise the weights. Defaults to Python's built-in hash function.
· iter: number of iterations (epochs) over the corpus, default 5.
· trim_rule: vocabulary trimming rule, specifying which words to keep and which to drop. Can be None (min_count is used) or a callable that accepts (word, count, min_count) and returns utils.RULE_DISCARD, utils.RULE_KEEP or utils.RULE_DEFAULT.
· sorted_vocab: if 1 (default), sort the vocabulary by descending frequency before assigning word indexes.
· batch_words: number of words passed to worker threads per batch, default 10000.
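For comparison, a skip-gram model with negative sampling would be configured roughly as below. This is only a sketch and is not used in the rest of the post; the keyword names follow gensim 3.x, which this post appears to use (in gensim 4.x, size was renamed to vector_size and iter to epochs).

from gensim.models.word2vec import Word2Vec

# Sketch only: skip-gram (sg=1) with 5 negative samples and hierarchical softmax disabled (hs=0).
sg_model = Word2Vec(corpus, size=500, window=5, min_count=5,
                    sg=1, hs=0, negative=5, workers=4)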

from gensim.models.word2vec import Word2Vec
model = Word2Vec(corpus, size=500, window=5, min_count=5, workers=4)
model['ok']

The word 'ok' is included in the vocabulary, and its vector has the 500 dimensions we configured. Output:

array([ 0.10630998,  0.14210293,  0.00959234,  0.03234905, -0.11406371,
       -0.04187847, -0.10550572,  0.01617744, -0.06915743, -0.09991685,
        0.0097149 ,  0.02334215,  0.06395814,  0.03067525,  0.01015951,
       -0.0262022 ,  0.00829123, -0.08188825,  0.017802  , -0.0435061 ,
        0.01704133,  0.1207227 ,  0.05831534, -0.02610667,  0.09159925,
        0.09652074,  0.00345927, -0.06801105,  0.08994583,  0.10310271,
        0.08756009,  0.04386705,  0.1464994 ,  0.08558127,  0.00778829,
        0.07385497, -0.07248272,  0.10217876, -0.06180608,  0.03155032,
       -0.03693158,  0.03071733,  0.02307764,  0.0266828 , -0.01686509,
        0.12422591,  0.08230562,  0.05776121, -0.09978634,  0.02537541,
       -0.06936303,  0.00428672, -0.08756024, -0.12032714, -0.07839552,
       -0.00336692,  0.01421487,  0.00830022, -0.0212637 , -0.03765199,
        0.02266566,  0.02933999, -0.03737787, -0.03587721,  0.01740207,
        0.01884088,  0.07537359,  0.0796692 ,  0.03100466,  0.02930304,
        0.06078303,  0.0516209 , -0.06951252,  0.03517171,  0.04363167,
        0.02802708, -0.02252885,  0.01897138,  0.05065971,  0.01507939,
        0.08021148,  0.06284177, -0.00828255,  0.02805166, -0.04506287,
       -0.03337098, -0.15036598, -0.08192373, -0.110293  ,  0.10606571,
       -0.03109092,  0.01939313, -0.06605176, -0.02609161,  0.0027284 ,
        0.03812803,  0.03839679, -0.02783109,  0.12833288,  0.06448438,
       -0.04182523,  0.0375215 ,  0.02423172,  0.03802066,  0.0657379 ,
       -0.01455859, -0.09476987,  0.01737102,  0.08638109,  0.05465818,
        0.06491291,  0.04766256,  0.01754146,  0.11373118, -0.01488762,
       -0.03831841,  0.02579844,  0.08154879,  0.11227204, -0.05088274,
        0.00875485,  0.09823024,  0.02201955,  0.05845467,  0.10193563,
        0.05104215,  0.0158309 ,  0.09639726,  0.00914135,  0.02546973,
        0.13866955, -0.05752271,  0.03832977, -0.10237459, -0.08396969,
        0.03217989,  0.02531725, -0.0626449 ,  0.05346811,  0.01434826,
        0.05855142, -0.05202417,  0.02953455, -0.01062565, -0.0900347 ,
        0.07571767,  0.03393933, -0.05042863,  0.01061901,  0.01457398,
        0.00845148, -0.09595269,  0.07767481, -0.14466994,  0.00041045,
        0.01950366, -0.04252251, -0.13291493,  0.08225287, -0.07122978,
       -0.0293374 ,  0.07360036, -0.0381012 ,  0.04461425, -0.1385132 ,
        0.02107334, -0.02360223,  0.0758872 ,  0.06353981, -0.1180307 ,
       -0.16683382,  0.02000329, -0.05918222,  0.01400798,  0.06885161,
        0.01744867, -0.05448931,  0.06055044,  0.04077809,  0.06905974,
       -0.02946478,  0.01766725,  0.04666409,  0.04484922, -0.04900419,
        0.02925754, -0.08012831,  0.01098989, -0.03396216,  0.00939972,
       -0.01157847, -0.0226005 ,  0.07821315,  0.09635648, -0.03415763,
       -0.05766468,  0.09330291,  0.0492701 , -0.03537293,  0.07540825,
        0.05161836,  0.19003753, -0.04706846,  0.02586077, -0.05152024,
        0.01372254,  0.0632927 ,  0.00351941,  0.12191376,  0.0331907 ,
       -0.00720628,  0.1097018 , -0.00390838, -0.06126651, -0.00396858,
       -0.014973  , -0.13959153,  0.03812763,  0.07050537, -0.0261981 ,
        0.07624791, -0.01999989,  0.08195515,  0.0138115 , -0.02143659,
       -0.1221185 , -0.0689826 ,  0.00108599,  0.03837402,  0.02080408,
        0.08491726,  0.06150453, -0.04688453, -0.00450647, -0.11543091,
       -0.06753737,  0.14967291,  0.05745384,  0.08077233,  0.00245412,
       -0.04259247,  0.01369479, -0.02748694, -0.05678632,  0.01031494,
        0.12646398,  0.07089368, -0.02903853,  0.00815106, -0.02404015,
        0.17015406,  0.00533959, -0.02955618,  0.0456912 ,  0.01093624,
       -0.05910192,  0.08882368,  0.09565608,  0.00769056,  0.05817396,
       -0.03566754, -0.07642328, -0.00361639, -0.00116393, -0.03274405,
        0.12187891,  0.06580611,  0.08087548,  0.0056507 ,  0.07588522,
        0.02990477, -0.0159376 , -0.07681391, -0.10478649, -0.03572183,
        0.08408619,  0.07956488, -0.04065475, -0.05456991, -0.00155666,
       -0.01889308, -0.03381919,  0.01810199,  0.07423308, -0.06805349,
        0.08052777, -0.05395509,  0.05253442,  0.09202667, -0.0088914 ,
       -0.04823348, -0.03835667,  0.02839028,  0.02154875,  0.05142747,
        0.07458628,  0.02765298,  0.01020699,  0.03043701, -0.04790074,
       -0.06363068, -0.01153451,  0.02389097,  0.04536989, -0.01082927,
        0.05113677, -0.01200369,  0.04977348, -0.11610536,  0.00254656,
        0.01619265,  0.09359486,  0.07980918, -0.06380153,  0.03626567,
        0.05727048, -0.09218347,  0.04325103,  0.07287478, -0.12025584,
        0.00535176, -0.09048352, -0.02627514,  0.01658515, -0.06529512,
       -0.00386042,  0.06642924, -0.01560184,  0.07863382, -0.1937443 ,
       -0.03262932, -0.06985268,  0.01623991,  0.03910165,  0.00950554,
        0.1285626 , -0.11118279, -0.03705613, -0.01390279,  0.04396483,
        0.0496658 ,  0.04527789, -0.03675282, -0.03847364, -0.03028795,
        0.00923345, -0.02778282,  0.02306183,  0.02428048,  0.01268121,
        0.01417804, -0.07867095, -0.14611882,  0.04435711,  0.03162471,
        0.02029084,  0.00268042,  0.07459821, -0.03080082, -0.0423187 ,
        0.16432528, -0.03147165, -0.08362244,  0.07195573, -0.06889871,
        0.11954063,  0.10902878,  0.03131651, -0.04852001,  0.01603564,
        0.08147104,  0.02034297, -0.04282593, -0.06261195, -0.01161425,
        0.0199711 ,  0.041841  , -0.01313252, -0.06413952, -0.01367164,
        0.03570195,  0.00708226, -0.04031646,  0.04528132, -0.02119712,
       -0.05654348, -0.02537308, -0.07106056, -0.03181816,  0.15453419,
        0.06198772, -0.06934739,  0.09368483, -0.03431397, -0.04964959,
       -0.04364007, -0.00948998,  0.05827446, -0.0734807 ,  0.12710157,
        0.03669912, -0.10625352, -0.06186994, -0.10142536,  0.05993845,
        0.08987869,  0.04097576, -0.03678257,  0.0762139 ,  0.00946183,
        0.03653205,  0.05962306, -0.02681629, -0.02996796, -0.06433427,
        0.00521787,  0.00777637,  0.02280867,  0.08364372,  0.01565453,
        0.01401456, -0.04469008,  0.00916688, -0.01250234,  0.01992061,
       -0.01215584,  0.14897361, -0.00532461, -0.00190283, -0.03165984,
       -0.02664695,  0.02463158,  0.02106988,  0.06877461, -0.02750051,
        0.03929004, -0.04451486, -0.00840878, -0.03764717, -0.07900967,
       -0.03950715, -0.04896978, -0.05955911, -0.03691145, -0.03284043,
       -0.01727171,  0.11458033,  0.09035305, -0.08975063, -0.08778729,
        0.06654128, -0.02913619, -0.02908804, -0.06608992, -0.01623631,
        0.05511775,  0.00757801,  0.09175   , -0.0486589 ,  0.08999854,
       -0.00039395, -0.04856401, -0.01378971, -0.01551514, -0.04831987,
        0.13319787,  0.119831  ,  0.05863398, -0.15051669,  0.02254639,
        0.09188102, -0.02847377,  0.00488674, -0.04727001,  0.0014068 ,
        0.02884379,  0.01815874,  0.04246284,  0.06079307, -0.00693732,
        0.01785832,  0.07241053, -0.07631   ,  0.0710271 , -0.0016337 ,
       -0.023894  ,  0.04001028,  0.03364684,  0.04026045, -0.00322365,
       -0.00919932, -0.06855992, -0.10619757, -0.03903317, -0.00156592,
       -0.00353619, -0.12425131, -0.01167665, -0.02533508,  0.00049084],
      dtype=float32)
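Beyond looking up a single vector, the trained model can also be queried through model.wv; a small sketch (model['ok'] above is the older gensim shortcut for the same lookup, and the exact neighbours returned depend on the trained corpus):

# Nearest neighbours of 'ok' by cosine similarity in the 500-dimensional space.
print(model.wv.most_similar('ok', topn=5))

# The same vector as model['ok'], accessed through the KeyedVectors object.
print(model.wv['ok'].shape)   # (500,)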

Now we represent our dataset with the NLP model. Strictly speaking we only have a vector per word, but since the corpus is small, here we simply take the average of all the word vectors in a document as its representation (in practice it is of course not that simple :).

def get_SenVector(word_list):
    # Average the 500-d word2vec vectors of all in-vocabulary words in a document.
    res = np.zeros([500])
    count = 0
    for word in word_list:
        if word in model:   # skip words dropped by min_count
            res += model[word]
            count += 1
    if count == 0:          # guard against documents with no in-vocabulary words
        return res
    return res / count      # element-wise mean
X_train = [get_SenVector(x) for x in X_train]
X_test = [get_SenVector(x) for x in X_test]

print(X_train[0])
[ 2.00956426e-01  2.79742372e-01  5.43111131e-02  2.32447913e-02
 -2.03764850e-01 -1.11120168e-01 -1.91787604e-01  6.81927016e-02
 -1.20990735e-01 -1.93873805e-01  4.95753614e-02  1.77553677e-02
  9.37306275e-02  7.68671597e-02  2.84631399e-02  3.75112822e-03
  1.86445602e-03 -9.60629956e-02 -1.31615048e-03 -1.04799383e-01
 -9.76756437e-03  2.24999646e-01  1.46699985e-01 -3.96265102e-02
  1.51601496e-01  1.58062636e-01  9.68156978e-03 -9.32917417e-02
  1.57221533e-01  1.23955827e-01  1.69706386e-01  6.28458804e-02
  2.53142305e-01  1.89036582e-01  4.00226182e-02  1.16499114e-01
 -9.68658641e-02  1.84393458e-01 -1.29831690e-01  8.48883264e-02
 -7.42644173e-02  9.89744333e-02  7.14745289e-02  3.21899949e-02
  1.73932285e-02  2.26697670e-01  1.35726220e-01  1.50887781e-01
 -1.82898487e-01  2.87747538e-02 -1.08682896e-01 -9.00064442e-03
 -1.72406798e-01 -1.92105774e-01 -1.22267460e-01 -5.81686678e-03
  3.10989163e-02  5.59818346e-03 -4.26548725e-02 -8.76043261e-02
  5.00441188e-02  3.55156964e-02 -7.94964448e-03 -3.24362540e-02
  3.15290991e-02  4.06513471e-02  1.65457846e-01  8.31335514e-02
  4.74402016e-02  1.35234928e-02  1.07568507e-01  6.37061403e-02
 -1.05323343e-01  2.34079122e-02  3.17265517e-02  6.90218910e-02
 -7.59793016e-02  5.09069335e-02  4.83609996e-02  1.30016145e-02
  1.75513611e-01  1.10979337e-01 -1.17937579e-02  3.52583058e-03
 -1.28181779e-01 -4.14026111e-02 -2.47185922e-01 -1.05398563e-01
 -2.16113735e-01  1.63976758e-01 -1.52091442e-02  8.29428027e-02
 -1.54646862e-01 -6.23414505e-02  3.93728887e-02  3.33062339e-02
  1.15438251e-01 -5.19043214e-02  2.31392679e-01  1.27576523e-01
 -8.98780334e-02  5.29201892e-02  4.02508703e-02  7.40061912e-02
  1.44462197e-01 -6.05372694e-02 -1.43486817e-01  5.05796102e-02
  1.54503254e-01  9.09434986e-02  1.15397283e-01  3.54002971e-02
  6.15072781e-02  1.69677496e-01 -4.35937810e-02 -8.41200975e-02
  6.42602376e-02  1.23516274e-01  1.96197269e-01 -8.20691148e-02
  4.07976812e-03  1.91944269e-01  4.22257526e-02  8.28031529e-02
  1.92181544e-01  8.99888016e-02  2.73886770e-03  1.45458457e-01
  1.45131259e-02  5.92500686e-02  2.23680875e-01 -9.80679772e-02
  8.68096365e-02 -2.12567888e-01 -1.40583022e-01  7.35269372e-02
  5.57164352e-02 -7.97126866e-02  9.79150570e-02  3.71103812e-02
  1.27759617e-01 -8.12472175e-02  1.48090086e-02 -4.88710359e-02
 -1.24548006e-01  1.10383811e-01  8.60712033e-02 -1.39315692e-01
  6.16504920e-02  4.69608894e-02  2.23465919e-02 -1.43144960e-01
  1.22769176e-01 -2.91146998e-01  3.89582708e-03  8.01743847e-02
 -1.21529409e-01 -2.44077997e-01  1.32286928e-01 -1.26766109e-01
 -4.35151375e-03  1.58357935e-01 -5.26782329e-02  5.10953632e-02
 -2.31957788e-01  7.66543052e-02 -7.72246781e-03  1.17372726e-01
  1.25382580e-01 -2.21051726e-01 -2.92529154e-01  3.68778044e-02
 -1.26029867e-01  1.21492649e-02  1.40779504e-01  8.23362073e-02
 -9.17491265e-02  1.13655603e-01  5.05413021e-02  1.49552148e-01
 -7.40656530e-02  3.87103266e-02  1.43244449e-02  3.86970098e-02
 -1.18868671e-01  7.57293746e-02 -1.50618526e-01 -3.82869574e-02
 -1.24631002e-01 -1.32244606e-02  2.12409958e-04 -6.40240046e-02
  1.54905930e-01  1.46542267e-01 -3.27342065e-02 -9.58355747e-02
  1.71355341e-01  1.06908814e-01 -4.54090724e-02  1.27019554e-01
  1.06058682e-01  3.28367048e-01 -9.86997791e-02  6.20614137e-02
 -5.82535396e-02  4.73201858e-02  1.07857891e-01  8.95513298e-03
  1.89696181e-01  5.91947402e-02 -2.30979476e-02  1.61394715e-01
 -3.07537424e-02 -1.30825215e-01  6.26812864e-03 -2.40680310e-02
 -2.60209007e-01  8.86952283e-02  1.64611413e-01 -8.73923820e-02
  1.66881673e-01 -6.65736422e-02  1.42201816e-01 -1.58247064e-02
 -1.17324560e-02 -2.29963096e-01 -1.21588967e-01 -1.78779200e-02
  7.93738478e-02  7.00850774e-02  1.92414653e-01  1.03311915e-01
 -9.33425235e-02 -2.10026596e-02 -1.74908915e-01 -9.46189302e-02
  2.19338049e-01  1.09793386e-01  1.54378135e-01 -3.07656677e-02
 -5.51875258e-02  2.12825544e-02 -2.27793089e-02 -1.16771672e-01
  1.39456701e-02  2.07863710e-01  1.91619067e-01 -2.86255996e-02
  3.66691942e-02 -5.66667754e-02  3.25125420e-01 -2.31463893e-02
 -6.78081237e-02  7.28349476e-02  3.09859126e-02 -7.73620820e-02
  1.41147345e-01  1.60774402e-01  7.48556669e-03  8.03264207e-02
 -3.04965135e-02 -1.25146214e-01  5.72605871e-03 -6.58435571e-03
 -5.80073502e-02  1.87410536e-01  1.14162795e-01  1.47636937e-01
  2.17942717e-02  1.40425723e-01  5.73065967e-02  1.30722551e-02
 -1.46627257e-01 -2.16129860e-01 -4.35938523e-02  1.43187069e-01
  1.27504220e-01 -2.16193847e-02 -1.25920576e-01 -4.44597808e-02
 -5.91805534e-02 -3.03648089e-02  3.77896146e-02  1.34765932e-01
 -1.57186796e-01  1.50709461e-01 -1.25444591e-01  1.01203443e-01
  1.34650465e-01  1.29183463e-02 -1.04420089e-01 -4.56423910e-02
  3.64473340e-02  6.60615947e-02  1.40592238e-01  9.73877022e-02
  3.97362893e-02  1.35316356e-02  4.68946071e-02 -1.21442327e-01
 -1.06581585e-01 -8.72870991e-03  4.36898069e-02  8.80580218e-02
 -1.04094986e-02  9.48618723e-02  1.87805711e-02  9.48720357e-02
 -2.30938419e-01 -1.36086011e-03  1.60973478e-02  1.81363362e-01
  1.66127059e-01 -1.36092420e-01  2.79059782e-02  1.05476725e-01
 -1.50561515e-01  1.17248966e-01  1.37489018e-01 -2.21281417e-01
  3.02344408e-04 -1.50787099e-01 -5.67884867e-02  3.81081952e-02
 -1.33750390e-01 -3.50595772e-02  1.03577581e-01 -3.31824061e-03
  1.20532916e-01 -3.81878174e-01 -7.54892288e-02 -1.72322991e-01
  1.35813533e-02  9.48713538e-02 -2.12762633e-02  2.19678561e-01
 -2.42252303e-01 -1.15704210e-01 -1.48369844e-03  6.52965485e-02
  6.38131021e-02  4.17122853e-02 -7.61673624e-02 -8.37611178e-02
 -8.50498631e-02 -1.88154690e-02 -4.21795369e-02  3.47064640e-02
  5.77279267e-02  2.48269130e-02  4.39692672e-02 -1.31235421e-01
 -2.54341345e-01  6.80526364e-02  8.34737414e-02 -4.63941797e-03
  1.79419485e-02  7.82997771e-02 -5.04494003e-02 -1.03251959e-01
  3.00185008e-01 -4.95542904e-02 -1.75960632e-01  1.82799361e-01
 -1.17882140e-01  1.90905241e-01  1.87304827e-01  7.75802108e-02
 -1.12863491e-01  2.15391480e-02  1.28834444e-01  4.47913921e-02
 -8.58652304e-02 -1.16247011e-01 -2.79785313e-02  3.68407854e-02
  1.09211315e-01 -1.98589772e-02 -1.36987145e-01 -8.41247461e-03
  2.86111396e-02  2.22105933e-02 -6.24533426e-02  3.54031355e-02
 -6.12892575e-02 -5.70728229e-02 -3.72833901e-02 -1.30805256e-01
 -8.43997093e-02  2.81295860e-01  1.38329746e-01 -1.45396456e-01
  1.62211141e-01 -8.24360748e-02 -1.04250383e-01 -4.54866262e-02
 -3.86525848e-02  9.99185909e-02 -1.23922234e-01  2.14669035e-01
  7.56259458e-02 -1.94778280e-01 -1.30808893e-01 -1.86855416e-01
  1.34116079e-01  9.03336971e-02  1.03046902e-01 -8.04189963e-02
  1.38866469e-01  5.52094458e-02  7.23564627e-02  1.21122781e-01
 -1.00960097e-01 -9.43107959e-02 -8.79738740e-02  4.34702776e-03
  3.89866823e-03  7.97535764e-02  1.35820977e-01  2.35113174e-02
  2.63540676e-03 -7.17171232e-02 -2.21171134e-02 -5.23557422e-02
  7.60709397e-02  4.28760472e-03  2.44348558e-01 -1.26387565e-02
 -1.60394881e-02 -9.84835143e-02 -3.70367006e-02  1.25520883e-02
  1.67003727e-02  8.26456874e-02 -4.49271778e-02  8.42750188e-02
 -6.54822706e-02 -2.46951909e-03 -1.05153090e-01 -1.82123622e-01
 -1.12998478e-02 -8.84098809e-02 -1.25152013e-01 -7.66165350e-02
 -9.97448233e-02 -9.62482751e-02  1.95394598e-01  1.70309723e-01
 -1.59377701e-01 -1.95059944e-01  8.05459647e-02 -1.03828863e-01
 -3.54468224e-02 -6.29922144e-02 -5.59982755e-02  9.42109432e-02
  6.56107344e-02  1.45893517e-01 -1.02923370e-01  1.90396754e-01
  3.45475346e-02 -6.61857937e-02 -5.78199246e-02 -2.13382053e-03
 -6.01890663e-02  2.05773049e-01  2.29034371e-01  8.00300694e-02
 -2.62622445e-01  5.97598851e-02  1.37743440e-01 -4.35836775e-02
  2.94699118e-03 -8.42902727e-02  3.44408557e-02  2.98737891e-02
  2.48815227e-02  7.08940178e-02  8.95116312e-02 -1.95732052e-02
  2.60254222e-02  1.22457864e-01 -1.87593007e-01  1.50857930e-01
  2.64159378e-02 -6.54221746e-02  6.91774550e-02  7.73571687e-02
  6.31879361e-02 -2.10712188e-02  2.96409766e-03 -1.85884842e-01
 -1.33677583e-01 -5.67982069e-02 -1.46997925e-02  7.64317824e-04
 -2.13379689e-01 -4.46067577e-02 -1.46950387e-02  6.88940654e-02]

4. Prediction with machine learning
Each of the 500 dimensions is a continuous value, and the dimensions are correlated with one another rather than independent, so methods like RandomForest that treat every column as an independent variable are not a good fit here. Instead we use an SVM. There are two versions below: one scores with cross-validation, the other uses GridSearchCV to pick the parameters; both are scored with ROC AUC.
(1) Cross-validation with cross_val_score:

from sklearn.svm import SVR
from sklearn.model_selection import cross_val_score

params = [0.1,0.5,1,3,5,7,10,12,16,20,25,30,35,40]
test_scores = []
for param in params:
    clf = SVR(gamma=param)
    test_score = cross_val_score(clf, X_train, y_train, cv=3, scoring='roc_auc')
    test_scores.append(np.mean(test_score))
import matplotlib.pyplot as plt
# %matplotlib inline makes plots render inside the notebook; it is only needed when
# working in Jupyter Notebook or jupyter qtconsole.
%matplotlib inline
plt.plot(params,test_scores)
plt.title("Param vs CV AUC Score")

The result (plot of gamma vs. mean cross-validated AUC):

We can see that gamma around 16 gives the best ROC score.
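As a follow-up not included in the original run, the chosen gamma could be checked directly against the held-out test set; SVR outputs a continuous score, which roc_auc_score accepts as-is (a sketch, with gamma=16 taken from the curve above):

# Refit with the gamma suggested by the cross-validation curve and score on the test period.
clf = SVR(gamma=16)
clf.fit(X_train, y_train)
print(roc_auc_score(y_test, clf.predict(X_test)))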
(2) GridSearchCV (the CV suffix means cross-validation is built in)

from sklearn.svm import SVC
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import GridSearchCV

parameters = {
    'gamma':[0.1,0.5,1,3,5,7,10,12,16,20,25,30,35,40],
    'C':[0.8,1,10]
}
test_scores = []
estimator = SVC(probability=True)
clf = GridSearchCV(estimator,parameters,cv=10,n_jobs=4)
clf.fit(X_train,y_train)
y_prediction = clf.predict_proba(X_test)
print(roc_auc_score(y_test,y_prediction[:,1]))
print(clf.best_params_)

The result:

0.4718021953405018
{'C': 1, 'gamma': 20}

The result is still not great, because we simply averaged the word vectors to represent each document. In practice it is never that simple; for example, one would keep the word vectors as a matrix and use a CNN for "dimensionality reduction + attention", etc. A sketch of that matrix representation follows.
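As an illustration of the "matrix of vectors" idea, each day's headlines could be kept as a fixed-size matrix instead of an averaged vector, by truncating or zero-padding the token list and stacking one word vector per row. A sketch only: max_len is an arbitrary choice, it would have to be applied to the tokenised word lists (i.e. before the averaging step above), and the resulting matrices would then feed a CNN or similar model.

def get_sen_matrix(word_list, max_len=256):
    # Stack one 500-d word vector per token, truncated or zero-padded to max_len rows.
    mat = np.zeros([max_len, 500])
    for i, word in enumerate(word_list[:max_len]):
        if word in model:          # same vocabulary check as in get_SenVector
            mat[i] = model[word]
    return mat

# Hypothetical usage on the tokenised (pre-averaging) documents:
# train_matrices = np.array([get_sen_matrix(x) for x in tokenised_train_docs])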
