"""
date: 2018_7_25
Compute sentence similarity with doc2vec.
(X_train is your own training corpus.)
"""
# coding:utf-8
import sys
import time
import csv
import glob
import gensim
import sklearn
import numpy as np
import jieba.posseg as pseg
import jieba
from gensim.models.doc2vec import Doc2Vec, LabeledSentence
# Alias kept for the rest of the script; gensim renamed LabeledSentence to
# TaggedDocument, and this binds the current class name (note the typo
# "TaggededDocument" is preserved because later code uses it).
TaggededDocument = gensim.models.doc2vec.TaggedDocument
def loadPoorEnt(path2 = 'G:/project/sentimation_analysis/data/stopwords.csv'):
    """Load a stopword list, one word per line.

    Args:
        path2: path to a UTF-8 text file with one stopword per line.

    Returns:
        list[str]: the stripped lines of the file.
    """
    # Use a context manager so the handle is closed (the original leaked it).
    with open(path2, encoding='UTF-8') as csvfile:
        return [line.strip() for line in csvfile.readlines()]
# Module-level stopword list, loaded once at import time; shared by cut().
stop_words = loadPoorEnt()
def cut(data):
    """Tokenize *data* with jieba POS segmentation, dropping stopwords.

    Args:
        data: a sentence string.

    Returns:
        list[list[str]]: a single-element list wrapping the kept words
        (callers index ``[0]`` to get the token list).
    """
    # The original bound the token list to the name `list`, shadowing the
    # builtin; a comprehension avoids that and the manual append loop.
    words = [item.word for item in pseg.cut(data)
             if item.word not in stop_words]
    return [words]
def get_all_content():
    """Collect every transcript CSV under the xuesu data directory."""
    pattern = r'D:/GFZQ/GFZQ/xuesu2018/xuesu/*.csv'
    return glob.glob(pattern)
def get_wenben(path):
    """Read a transcript CSV and return its rows.

    Args:
        path: path to a UTF-8 CSV file.

    Returns:
        list[list[str]]: all parsed rows.  The original returned a live
        csv.reader over a file handle that was never closed; returning a
        list lets the file be closed and callers only ever iterate it.
    """
    with open(path, 'r', encoding='UTF-8') as csvfile:
        return list(csv.reader(csvfile))
def get_QA(wenben):
    """Split transcript rows into question and answer columns.

    Each row is expected to carry the question at index 1 and the answer
    at index 2 (index 0 is presumably an id/speaker column — TODO confirm).

    Args:
        wenben: iterable of rows (may be a one-shot csv.reader).

    Returns:
        (combined, questions, answers) where combined = questions + answers.
    """
    # Materialize once: wenben may be a one-shot iterator (csv.reader),
    # so it can only be traversed a single time.
    rows = list(wenben)
    # Comprehensions replace the append loop; the original also shadowed
    # the builtin `all` for the combined list.
    Q_all = [row[1] for row in rows]
    A_all = [row[2] for row in rows]
    return Q_all + A_all, Q_all, A_all
def get_datasest(all_csv):
    """Build TaggedDocument training samples from per-company sentence lists.

    Args:
        all_csv: list of per-company lists of sentences (from get_csvfile).

    Returns:
        list: one TaggedDocument per sentence, tagged with its global index.
    """
    print(len(all_csv))
    # Flatten all companies' sentences into one list (was a nested loop).
    all_sent = [sent for file_one in all_csv for sent in file_one]
    x_train = []
    for i, text in enumerate(all_sent):
        # cut() returns a one-element list wrapping the token list.
        word_list = cut(text)
        # Removed the per-sentence len() debug print the original emitted.
        document = TaggededDocument(word_list[0], tags=[i])
        x_train.append(document)
    return x_train
def getVecs(model, corpus, size):
    """Stack the learned document vectors for *corpus* into an (N, size) array."""
    rows = []
    for doc in corpus:
        vec = model.docvecs[doc.tags[0]]
        rows.append(np.array(vec.reshape(1, size)))
    return np.concatenate(rows)
def train(x_train, size=200, epoch_num=70):
    """Train a Doc2Vec model on *x_train* and persist it to disk.

    Args:
        x_train: list of TaggedDocument samples.
        size: embedding dimensionality.  NOTE(review): `size=` is the
            gensim<4 keyword (4.x renamed it `vector_size`); the file's
            LabeledSentence import also only exists pre-4, so kept as-is.
        epoch_num: number of training epochs.  The original accepted this
            parameter but ignored it and hard-coded 70; the hard-coded
            value is now the default, so `train(x_train)` is unchanged.

    Returns:
        the trained Doc2Vec model.
    """
    model_dm = Doc2Vec(x_train, min_count=1, window=3, size=size,
                       sample=1e-3, negative=5, workers=4)
    model_dm.train(x_train, total_examples=model_dm.corpus_count,
                   epochs=epoch_num)
    model_dm.save('G:/project/sentimation_analysis/data/conference.model')
    return model_dm
def get_csvfile():
    """Parse up to 28 company transcript CSVs into combined Q+A lists.

    Returns:
        list: one element per company, each the concatenated
        questions + answers from that company's CSV.
    """
    all_files = get_all_content()
    # Cap at 28 files but never index past the end: the original
    # hard-coded 28 and raised IndexError when fewer CSVs were present.
    length = min(28, len(all_files))
    print("统计了%d家公司的情感词" % length)
    all_csv = []
    for i in range(length):
        print("正在解析第%d家公司" % i)
        rows = get_wenben(all_files[i])
        combined, _, _ = get_QA(rows)  # only the combined list is used
        all_csv.append(combined)
    return all_csv
def stest():
    """Load the trained model and return the 10 docs most similar to a probe.

    Returns:
        list[(tag, similarity)]: top-10 most similar training documents.
    """
    # Load the same path train() saves to — the original loaded
    # 'conference_model.csv', a file train() never writes.
    model_dm = Doc2Vec.load('G:/project/sentimation_analysis/data/conference.model')
    # NOTE(review): infer_vector expects a token list; passing the whole
    # sentence as one "token" matches the original behavior — confirm intent.
    test_text = ["我们是好孩子"]
    inferred_vector_dm = model_dm.infer_vector(test_text)
    sims = model_dm.docvecs.most_similar([inferred_vector_dm], topn=10)
    return sims
if __name__ == '__main__':
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # documented replacement for wall-clock interval timing.
    start = time.perf_counter()
    all_csv = get_csvfile()
    x_train = get_datasest(all_csv)
    model_dm = train(x_train)
    sims = stest()
    # Each sim is (training-doc tag, cosine similarity); the tag indexes
    # back into x_train.
    for count, sim in sims:
        sentence = x_train[count]
        print(sentence, sim)