基于sklearn库,搭建一个简单的问答系统

第一部分: 在这部分里,首先需要去读取给定的文件,并把文件里的内容读取到list里面。这部分的任务主要需要文件IO操作方面的基本知识。

# 读取文件
def read_corpus(file):
    with open(file) as f:
        list = []
        lines = f.readlines()
        for i in lines:
            list.append(i)
    return list
questions = read_corpus('./Question_combined.dat')
answers = read_corpus('./Answer_combined.dat')
assert len(questions)==len(answers), "问题和答案列表的大小不一样,请检查读入数据是否有误!"

第二部分: 处理已有的字符串数据,并把它们转换成词袋向量。这部分内容涉及到一些简单的字符串预处理技术(比如过滤掉一些没用的字符、分词等),还有就是基于sklearn的把字符串转换向量的过程。本部分的内容需要字符串操作、分词、词袋模型相关的基础知识。

import re
import jieba
from sklearn.feature_extraction.text import CountVectorizer

def filter_out_category(input):
    new_input = re.sub('[\u4e00-\u9fa5]{2,5}\\/','',input)
    return new_input

def filter_out_punctuation(input):
    new_input = re.sub('([a-zA-Z0-9])','',input)
    new_input = ''.join(e for e in new_input if e.isalnum())
    return new_input

def word_segmentation(input):
    new_input = ','.join(jieba.cut(input))
    return new_input

def conver2BOW(data):
    new_data = [] 
    for q in data:
        q = filter_out_category(q)  
        q = filter_out_punctuation(q)
        q = word_segmentation(q)
        new_data.append(q)
    vectorizer = CountVectorizer() 
    X = vectorizer.fit_transform(new_data)
    return vectorizer, X
vectorizer, X = conver2BOW(questions)

第三部分: 对于用户的新输入,返回答案。 这是最后一部分,也就是等我们创建完词袋向量之后,我们就可以输入一些新的问题,然后从库中找出最合适的答案。这部分的任务涉及到余弦相似度、简单搜索排序等方面基础知识。

import numpy as np
def idx_for_largest_cosine_sim(input, questions):
    list = []
    input = (input.toarray())[0]
    for question in questions:
        question = question.toarray()
        num = float(np.matmul(question, input))
        denom = np.linalg.norm(question) * np.linalg.norm(input)
        cos = num / denom
        list.append(cos)

    best_idx = list.index(max(list))
    return best_idx

def answer(input):
    input = filter_out_punctuation(input)
    input = word_segmentation(input)
    bow = vectorizer.transform([input])
    best_idx = idx_for_largest_cosine_sim(bow, X)
    return answers[best_idx]

输入问题,查看结果

print(answer("谁知道网上找兼职工作的网站"))

搜索结果如下:

这里没有对返回数据进行过清洗,否则体验会更好一些…

源码及测试数据已上传至git,点击这里可直接查看,有疑问的同学可提Issues或在博客下方留言~

猜你喜欢

转载自blog.csdn.net/lt326030434/article/details/82909589