信息检索第六次实验--NB算法的训练及分类过程

目录

朴素贝叶斯算法概述:

朴素贝叶斯算法流程:

代码实现:

结果展示:


朴素贝叶斯算法概述:

        在多项式模型中, 设某文档d=(t_{1},t_{2},…,t_{k}),t_{k}是该文档中出现过的单词,允许重复。

先验概率P(c)= c类文档数量/整个训练样本的文档总数;

类条件概率P(t_{k}|c)=(单词t_{k}在各个c类文档中出现过的次数之和+1)/(类c下单词总数+|V|);

  • P(c)可以认为是c类文档在所有文档中所占的比例
  • P(\bar{c})可以认为是非c类文档在所有文档中所占的比例
  • V是训练样本的单词表,|V|则表示训练样本包含的单词种类数
  • P(t_{k}|c)可以看作是单词t_{k}在证明文档d属于c类文档的条件概率
  • P(t_{k}|\bar{c})可以看作是单词t_{k}在证明文档d属于非c类文档的条件概

  • 最后根据事件的独立性可得,该测试文档中所有单词在该类别之下的条件概率之积为总的条件概率,再与先验概率计算乘积获得最大概率所在的类别

  

朴素贝叶斯算法流程

提取所有文档中的词条并进行去重\Rightarrow获取文档的所有类别\Rightarrow计算每个类别中的文档数目

先验概率计算:

对每篇训练文档: 
    对每个类别: 
        如果词条出现在文档中-->增加该词条的计数值
        增加所有词条的计数值

条件概率计算:

对每个类别: 
    对每个词条: 
        将该词条的数目除以总词条数目得到的条件概率(P(词条|类别))

总概率:

返回该文档属于每个类别的条件概率

代码实现:

# cn:统计c类的文档数量、lcn:统计所有训练集的文档数量、df:分类统计所有的文档内容信息
cn, lcn, df = 0, 0, {1: [], 0: []}

# 训练数据集的输入,每次输入之后询问训练集数据是否输入结束,结束之后直接跳出循环
while(1):
    d = input('请输入训练文档:').split(' ')            # 将输入的文档进行词条化
    c = eval(input('属于China类输入1,不属于输入0:'))   # 根据所给的标签对每篇文档进行类别统计【该类别和非该类别(2种)】
    if c == 1:
        cn += 1  # 统计c类的文档数量
    lcn += 1     # 统计所有的文档数量
    df[c] = df[c] + d  # 将每种类别的文档内容进行合并
    flag = input('要继续输入吗?[y|n]:')  # 循环判断训练集是否输入结束
    if flag == 'n':  # 结束,则跳出训练集信息的输入
        break

testd = input('请输入测试文档:').split()  # 输入测试文档内容信息

# df = {1: ['Chinese', 'Beijing', 'Chinese', 'Chinese', 'Chinese', 'Shanghai', 'Chinese', 'Macao'], 0: ['Tokyo', 'Japan', 'Chinese']}
# testd = ['Chinese', 'Chinese', 'Chinese', 'Tokyo', 'Japan']

testdin = set(testd)  # 将测试文档放在集合中【作用:去重】
# 计算先验概率
Pc = cn / lcn         # 计算 C类的概率
FPc = 1 - Pc          # 计算非 C类的概率

Pctestd, FPctestd = 1, 1    # 初始化输出概率的初值为 1
# lc:c类文档的词条数量、Flc:非c类文档的词条数量、l:所有文档中的词条数量
lc, Flc, l = len(df[1]), len(df[0]), len(set(df[1] + df[0]))
for i in testdin:        # 遍历测试文档中所有的词项
    mi = testd.count(i)  # 统计测试文档中词条的数量
    # 计算条件概率
    Pctestd *= ((df[1].count(i) + 1) / (lc + l)) ** mi    # c类每个元素的概率乘积,使用了+1平滑
    FPctestd *= ((df[0].count(i) + 1) / (Flc + l)) ** mi  # 非c类每个元素的概率乘积,使用了+1平滑

# 计算输出概率(先验概率*条件概率)
Pctestd = round(Pctestd * Pc, 4)     # 保留小数点后4位
FPctestd = round(FPctestd * FPc, 4)  # 保留小数点后4位
flag = 'China' if Pctestd > FPctestd else '非China'    # 比较获得大概率的类别并输出
print('\n计算属于China的概率为{},不属于China类的概率为{},\n所以分类器将测试文档分到{}类。'.format(Pctestd, FPctestd, flag))

结果展示:

猜你喜欢

转载自blog.csdn.net/rui_qi_jian_xi/article/details/130509334