TextRank Spark 实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/mzg12345678/article/details/78611861
#! -*- coding:utf-8 -*-
import os
import sys
import math
from pyspark import SparkContext
from pyspark import StorageLevel
from pyspark.conf import SparkConf 


#os.environ['PYSPARK_PYTHON'] = './pypy/bin/pypy'
import numpy as np


reload(sys)
sys.setdefaultencoding('utf8')
def cosin(vec1,vec2):
    if len(vec1) != len(vec2):
        return 0
    
    norm1 = 0.0
    for t in vec1:
        norm1 += t*t
    norm1 = math.sqrt(norm1)    
    
    norm2 = 0.0
    for t in vec2:
        norm2 += t*t
    norm2 = math.sqrt(norm2)    


    if norm1 < 0.0000001 or norm2 < 0.0000001:
        return 0


    sum = 0.0
    for i in range(len(vec1)):
        sum += vec1[i]*vec2[i]


    return sum*1.0/(norm1*norm2)            




def cal_sim(word_vec,w1,w2):
    mdict = word_vec.value 


    if w1 not in mdict:
        return 0.0
    if w2 not in mdict:
        return 0.0
    return cosin(mdict[w1],mdict[w2])


class TextRank(object):  
      
    def __init__(self, sentence, window, alpha, iternum):  
        self.sentence = sentence  
        self.window = window  
        self.alpha = alpha  
        self.edge_dict = {} #记录节点的边连接字典  
        self.iternum = iternum#迭代次数  
  
    #对句子进行分词  
    def cutSentence(self):  
        self.word_list = self.sentence.strip().split(" ")  
  
    #根据窗口,构建每个节点的相邻节点,返回边的集合  
    def createNodes(self):  
        tmp_list = []  
        word_list_len = len(self.word_list)  
        for index, word in enumerate(self.word_list):  
            if word not in self.edge_dict.keys():  
                tmp_list.append(word)  
                tmp_set = set()  
                left = index - self.window + 1#窗口左边界  
                right = index + self.window#窗口右边界  
                if left < 0: left = 0  
                if right >= word_list_len: right = word_list_len  
                for i in range(left, right):  
                    if i == index:  
                        continue  
                    tmp_set.add(self.word_list[i])  
                self.edge_dict[word] = tmp_set  
  
    #根据边的相连关系,构建矩阵  
    def createMatrix(self,word_vec):  
        self.matrix = np.zeros([len(set(self.word_list)), len(set(self.word_list))])  
        self.word_index = {}#记录词的index  
        self.index_dict = {}#记录节点index对应的词  
  
        for i, v in enumerate(set(self.word_list)):  
            self.word_index[v] = i  
            self.index_dict[i] = v  
        for key in self.edge_dict.keys():  
            for w in self.edge_dict[key]:  
                sim = cal_sim(word_vec,key,w)
                if sim < 0.1:
                    sim = 0.0
                self.matrix[self.word_index[key]][self.word_index[w]] = sim 
                self.matrix[self.word_index[w]][self.word_index[key]] = sim
        #归一化  
        for j in range(self.matrix.shape[1]):  
            sum = 0  
            for i in range(self.matrix.shape[0]):  
                sum += self.matrix[i][j]
            for i in range(self.matrix.shape[0]):  
                if sum > 0.001:
                    self.matrix[i][j] /= sum  
  
    #根据textrank公式计算权重  
    def calPR(self):  
        self.PR = np.ones([len(set(self.word_list)), 1])  
        for i in range(self.iternum):  
            self.PR = (1 - self.alpha) + self.alpha * np.dot(self.matrix, self.PR)  
  
    #输出词和相应的权重  
    def getResult(self):  
        word_pr = {}  
        for i in range(len(self.PR)):  
            word_pr[self.index_dict[i]] = self.PR[i][0]  
        res = sorted(word_pr.items(), key = lambda x : x[1], reverse=True)[:20]  
        #print(res)
        
        out =""
        for (k,v) in res:
            out += "%s\t%f " %(k,v)
        return out.strip()            
  


if __name__=="__main__":
    os.environ['PYSPARK_PYTHON'] = './pypy/bin/pypy'
    def extract_key(line,word_vec):
        tlist = line.split("\t")
        if len(tlist) != 2:
            return query
        query = tlist[0]
        text = tlist[1]


        tr = TextRank(text, 5, 0.85, 50)
        tr.cutSentence()  
        tr.createNodes()  
        tr.createMatrix(word_vec)  
        tr.calPR()  
        return query+"\t" + tr.getResult()  
    def parse_vector(line):
        tlist = line.strip().split(" ")
        return (tlist[0],map(float,tlist[1:]))


    conf = SparkConf().setAppName("TextRank")


    sc = SparkContext(conf=conf)
    input = sc.textFile("")




    word_vec_raw = sc.textFile("") \
                    .filter(lambda line:len(line.strip().split(" ")) == 201) \
                    .map(lambda line: parse_vector(line)).collectAsMap()
          
    word_vec = sc.broadcast(word_vec_raw)


    
    input.map(lambda line: extract_key(line,word_vec)).saveAsTextFile("")
    


    sc.stop()

猜你喜欢

转载自blog.csdn.net/mzg12345678/article/details/78611861
今日推荐