Implementing Bayesian Inference in Python: Spam Email Classification


Theory

For the theory, I strongly recommend two posts on Ruan Yifeng's personal site:
1. Bayesian Inference and Its Internet Applications (I): An Introduction to the Theorem
2. Bayesian Inference and Its Internet Applications (II): Filtering Spam
They are very clear and easy to follow; the code below implements the spam-filtering algorithm described there.
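For quick reference, these are the two formulas from those posts that the code below implements, transcribed into LaTeX (the notation is mine). For a word $W$ appearing in an email, with priors $P(S) = P(H) = 0.5$ for spam and ham:

$$P(S \mid W) = \frac{P(W \mid S)\,P(S)}{P(W \mid S)\,P(S) + P(W \mid H)\,P(H)}$$

The per-word probabilities $P_1, \dots, P_n$ for an email are then combined into a single spam probability:

$$P = \frac{P_1 P_2 \cdots P_n}{P_1 P_2 \cdots P_n + (1-P_1)(1-P_2)\cdots(1-P_n)}$$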

Preparation

Data source

The data comes from the naive Bayes classifier experiment in Chapter 4 of Machine Learning in Action. The book provides only 50 emails (25 ham, 25 spam), which feels rather small; in the future I plan to use the iris data provided by scikit-learn.

Data preparation

As in most machine learning work, the data needs to be split into a training set and a test set.
The split works as follows:
1. Walk the email folder containing the 50 messages and collect the file list.
2. Shuffle the list with random.shuffle().
3. Move the first 10 file paths of the shuffled list into the test folder to serve as the test set.
Code:

# -*- coding: utf-8 -*-
# @Date     : 2017-05-09 13:06:56
# @Author   : Alan Lau ([email protected])
# @Language : Python3.5

# from fwalker import fun
import random
# from reader import writetxt, readtxt
import shutil
import os

def fileWalker(path):
    # recursively collect the full path of every file under `path`
    fileArray = []
    for root, dirs, files in os.walk(path):
        for fn in files:
            eachpath = str(root+'\\'+fn)
            fileArray.append(eachpath)
    return fileArray

def main():
    filepath = r'..\email'
    testpath = r'..\test'
    files = fileWalker(filepath)
    random.shuffle(files)  # randomize the 50 emails
    top10 = files[:10]     # the first 10 become the test set
    for ech in top10:
        # build a name like 'ham_3.txt' / 'spam_8.txt' from the last two
        # path components, then move the file and rename it
        ech_name = testpath+'\\'+('_'.join(ech.split('\\')[-2:]))
        shutil.move(ech, testpath)
        os.rename(testpath+'\\'+ech.split('\\')[-1], ech_name)
        print('%s moved' % ech_name)


if __name__ == '__main__':
    main()

*If you are wondering about the fwalker and reader packages in the commented-out imports, see my earlier posts on importing your own .py files in Python and on common Python 3 code for reading and writing text files.
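One caveat: the path handling above is Windows-specific (paths are glued together and split with '\\'). As a minimal sketch, a portable variant of fileWalker could build paths with os.path.join instead (the split('\\') calls elsewhere would then need os.path.basename rather than manual splitting):

import os

def fileWalker(path):
    # portable variant: os.path.join picks the right separator for the host OS
    fileArray = []
    for root, dirs, files in os.walk(path):
        for fn in files:
            fileArray.append(os.path.join(root, fn))
    return fileArray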

The resulting file lists are shown as screenshots in the original post: the ham list, the spam list, and the test list (the copy folder is just a backup of the data, in case a file operation goes wrong).

After this preparation, the test set contains 7 spam emails and 3 ham emails.

Implementation

# -*- coding: utf-8 -*-
# @Date     : 2017-05-09 09:29:13
# @Author   : Alan Lau ([email protected])
# @Language : Python3.5

# from fwalker import fun
# from reader import readtxt
import os


def readtxt(path, encoding):
    # read a text file and return its lines as a list
    with open(path, 'r', encoding=encoding) as f:
        lines = f.readlines()
    return lines

def fileWalker(path):
    # recursively collect the full path of every file under `path`
    fileArray = []
    for root, dirs, files in os.walk(path):
        for fn in files:
            eachpath = str(root+'\\'+fn)
            fileArray.append(eachpath)
    return fileArray

def email_parser(email_path):
    # characters treated as word separators
    punctuations = """,.<>()*&^%$#@!'";~`[]{}|、\\/~+_-=?"""
    content_list = readtxt(email_path, 'utf8')
    content = (' '.join(content_list)).replace('\r\n', ' ').replace('\t', ' ')
    # replace every punctuation mark with a space
    for punctuation in punctuations:
        content = (' '.join(content.split(punctuation))).replace('  ', ' ')
    # keep lower-cased words longer than two characters; this also drops the
    # empty strings left behind by consecutive separators
    clean_word = [word.lower()
                  for word in content.split(' ') if len(word) > 2]
    return clean_word


def get_word(email_file):
    # returns (list of per-email word lists, set of all words seen)
    word_list = []
    word_set = []
    email_paths = fileWalker(email_file)
    for email_path in email_paths:
        clean_word = email_parser(email_path)
        word_list.append(clean_word)  # words of this one email
        word_set.extend(clean_word)   # accumulate the vocabulary
    return word_list, set(word_set)


def count_word_prob(email_list, union_set):
    # for every word in the joint vocabulary, estimate the fraction of
    # emails in email_list that contain it
    word_prob = {}
    for word in union_set:
        counter = 0
        for email in email_list:
            if word in email:
                counter += 1
        if counter != 0:
            prob = counter/len(email_list)
        else:
            # the word never occurs in this class: use a small floor
            # instead of 0, so one word cannot veto the whole email
            prob = 0.01
        word_prob[word] = prob
    return word_prob


def filter(ham_word_pro, spam_word_pro, test_file):
    test_paths = fileWalker(test_file)
    for test_path in test_paths:
        email_spam_prob = 0.0
        spam_prob = 0.5  # prior P(S)
        ham_prob = 0.5   # prior P(H)
        file_name = test_path.split('\\')[-1]
        prob_dict = {}
        words = set(email_parser(test_path))
        for word in words:
            Psw = 0.0
            if word not in spam_word_pro:
                # word never seen in training; assume it leans slightly
                # toward ham (0.4, following the referenced articles)
                Psw = 0.4
            else:
                # Bayes: P(S|W) = P(W|S)P(S) / (P(W|S)P(S) + P(W|H)P(H))
                Pws = spam_word_pro[word]
                Pwh = ham_word_pro[word]
                Psw = spam_prob*(Pws/(Pwh*ham_prob+Pws*spam_prob))
            prob_dict[word] = Psw
        # combine per-word probabilities:
        # P = prod(Pi) / (prod(Pi) + prod(1 - Pi))
        numerator = 1
        denominator_h = 1
        for k, v in prob_dict.items():
            numerator *= v
            denominator_h *= (1-v)
        email_spam_prob = round(numerator/(numerator+denominator_h), 4)
        if email_spam_prob > 0.5:
            print(file_name, 'spam', email_spam_prob)
        else:
            print(file_name, 'ham', email_spam_prob)


def main():
    ham_file = r'..\email\ham'
    spam_file = r'..\email\spam'
    test_file = r'..\email\test'
    ham_list, ham_set = get_word(ham_file)
    spam_list, spam_set = get_word(spam_file)
    union_set = ham_set | spam_set  # joint vocabulary of both classes
    ham_word_pro = count_word_prob(ham_list, union_set)    # P(W|H) estimates
    spam_word_pro = count_word_prob(spam_list, union_set)  # P(W|S) estimates
    filter(ham_word_pro, spam_word_pro, test_file)


if __name__ == '__main__':
    main()
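As a quick, hypothetical sanity check of the tokenization in email_parser (the input string and temp file here are made up; the expected output follows from the rules above, where punctuation becomes spaces and only lower-cased words of three or more characters survive):

import os
import tempfile

# write a throwaway email to disk and run it through email_parser
tmp = tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False, encoding='utf8')
tmp.write('Hello, World! This is a TEST.')
tmp.close()
print(email_parser(tmp.name))  # expected: ['hello', 'world', 'this', 'test']
os.remove(tmp.name)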

Results

ham_24.txt ham 0.0005
ham_3.txt ham 0.0
ham_4.txt ham 0.0
spam_11.txt spam 1.0
spam_14.txt spam 0.9999
spam_17.txt ham 0.0
spam_18.txt spam 0.9992
spam_19.txt spam 1.0
spam_22.txt spam 1.0
spam_8.txt spam 1.0

The accuracy here is 90%. Strictly speaking, though, all of the data should be split randomly into ten equal folds, with each fold serving once as the test set while the remaining nine are used for training, and the ten results averaged; only then would the measured accuracy of this model be reliable. Moreover, the small amount of data cannot be ruled out as a cause of the limited accuracy.
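A minimal sketch of that 10-fold procedure, assuming the 50 emails have been collected into a list of (path, label) pairs; train_and_score is a hypothetical helper that would rebuild the word probabilities from the training fold and return the fraction of the test fold classified correctly:

import random

def ten_fold(samples):
    # samples: list of (path, label) pairs, shuffled once up front
    random.shuffle(samples)
    fold_size = len(samples) // 10
    accuracies = []
    for i in range(10):
        # fold i is the test set; everything else is the training set
        test = samples[i*fold_size:(i+1)*fold_size]
        train = samples[:i*fold_size] + samples[(i+1)*fold_size:]
        accuracies.append(train_and_score(train, test))  # hypothetical helper
    return sum(accuracies) / len(accuracies)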

All of the code and data are available on GITHUB.
