[Machine Learning Basics] Classifying Bilibili (station B) danmaku with Naive Bayes (permanently banned or not)

Naive Bayes Classification


Preface

Naive Bayes classification is an application of Bayesian probability theory and a very commonly used classification algorithm.

1. Using conditional probability for classification

Compute the conditional probability of a sample under each category, and assign the sample to the category with the higher probability.
Notes

  • In the actual computation, log-probabilities are used, to avoid underflow when multiplying many small probabilities.
  • At initialization, the numerator and denominator of each word probability are set to non-zero values, with the denominator kept larger than the numerator (Laplace smoothing). This avoids zero probabilities and division by zero.
  • When comparing the conditional probabilities of different categories, only the numerators need to be computed (the denominator is the same for all categories); the formulas after this list make this precise.
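
More precisely (a standard formulation of the naive Bayes decision rule, using the same smoothing constants as the code below; the notation is mine, not from the original post): for a danmaku with tokens $w_1, \dots, w_n$ and class $c_i$,

$$
p(c_i \mid \mathbf{w}) = \frac{p(\mathbf{w} \mid c_i)\, p(c_i)}{p(\mathbf{w})},
\qquad
p(\mathbf{w} \mid c_i) = \prod_{j=1}^{n} p(w_j \mid c_i).
$$

Since $p(\mathbf{w})$ is identical for both classes, it suffices to compare the log-numerators

$$
\log p(c_i) + \sum_{j=1}^{n} \log p(w_j \mid c_i),
$$

where, with Laplace smoothing (word counts initialized to 1 and class totals to 2, exactly as in train_naive_bayes below),

$$
p(w_j \mid c_i) = \frac{n_{ij} + 1}{N_i + 2},
$$

with $n_{ij}$ the count of word $j$ in class-$i$ documents and $N_i$ the total word count of class $i$.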

2. Obtaining the small black house danmaku data from Bilibili (station B)

See the earlier blog post on crawling Bilibili's small black house. The data obtained is in the following JSON format:

[
    {'type': xxx,       // the ban status (e.g. '永久封禁')
     'article': xxx}    // the actual danmaku text
]

This article uses 2108 records; the first 2050 serve as the training set and the remainder as the test set.
Since the tokenization simply splits the text into individual Chinese characters and symbols (which is not very reasonable), the error rate is relatively high (about 25%), as sketched below.
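
For reference, here is a minimal sketch of that character-level split (the example string is made up for illustration):

# Character-level "tokenization": every Chinese character and symbol
# becomes its own token, so word boundaries are lost.
text = '这是一条弹幕!'
tokens = [ch for ch in text]
print(tokens)  # ['这', '是', '一', '条', '弹', '幕', '!']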

3. The code

import numpy as np
import json

class Bayes:
    
    def __init__(self):
        pass

    def loadData(self, url):
        # load the JSON file
        with open(url, 'r', encoding='utf-8') as f:
            data_dict = json.load(f)

        label = []
        posting_list = []
        self.data_size = len(data_dict)

        for item in data_dict:
            label.append(item['type'])
            # character-level split: each character/symbol is a token
            posting_list.append([word for word in item['article']])

        return_label = []

        # split into permanently banned (1) and not permanently banned (0)
        for example in label:
            if example == '永久封禁':  # '永久封禁' = "permanently banned"
                return_label.append(1)
            else:
                return_label.append(0)

        return posting_list, return_label

    def create_vocabulary_list(self, data_set):
        # build the vocabulary: the set of all distinct tokens
        vocabulary_list = set()
        for document in data_set:
            # union with the tokens of this document
            vocabulary_list = vocabulary_list | set(document)

        return list(vocabulary_list)

    def is_word_in_vocab(self, vocab_list, input_set):
        # build a set-of-words vector: 1 if the token appears in the
        # sentence, 0 otherwise
        return_vector = [0] * len(vocab_list)

        for word in input_set:
            if word in vocab_list:
                return_vector[vocab_list.index(word)] = 1
            else:
                print("Word: %s not detected!" % word)
        return return_vector

    def train_naive_bayes(self, train_set, train_category_set):
        # train the model; returns
        #   - the prior probability of a permanent ban
        #   - per-word log-probabilities given a permanently banned danmaku
        #   - per-word log-probabilities given a not permanently banned danmaku
        # train_set: matrix of document word vectors
        # train_category_set: list of document labels

        num_train_document = len(train_set)
        num_words = len(train_set[0])

        # prior probability of a permanently banned danmaku
        p_abusive = sum(train_category_set) / num_train_document

        # per-word occurrence counts across all documents (banned / not banned),
        # initialized to 1 for Laplace smoothing
        p_abusive_num = np.ones(num_words)
        p_unabusive_num = np.ones(num_words)

        # total word counts across all documents (banned / not banned),
        # initialized to 2 to match the smoothing above
        p_abusive_all_num = 2.0
        p_unabusive_all_num = 2.0

        for i in range(num_train_document):
            if train_category_set[i] == 1:  # permanently banned danmaku
                p_abusive_num += train_set[i]
                p_abusive_all_num += np.sum(train_set[i])
            else:
                p_unabusive_num += train_set[i]
                p_unabusive_all_num += np.sum(train_set[i])

        # take logs to avoid multiplication underflow
        # permanently banned danmaku
        p_abusive_vec = np.log(p_abusive_num / p_abusive_all_num)
        # not permanently banned danmaku
        p_unabusive_vec = np.log(p_unabusive_num / p_unabusive_all_num)

        return p_abusive, p_abusive_vec, p_unabusive_vec

    def classify(self, input_x, p_abusive, p_abusive_vec, p_unabusive_vec):
        # input_x is a word vector
        # only the numerators are compared; the denominator p(w) is the same

        p1 = np.sum(input_x * p_abusive_vec) + np.log(p_abusive)
        p0 = np.sum(input_x * p_unabusive_vec) + np.log(1.0 - p_abusive)

        if p1 > p0:
            return 1
        else:
            return 0

    def test(self):
        # test on the held-out danmaku
        URL = r'2020\ML\ML_action\\3.NaiveBayes\data\blackroom.json'
        posting_list, class_vec = self.loadData(URL)
        vocab_list = self.create_vocabulary_list(posting_list)
        print("len = ", len(posting_list))
        train_size = 2050
        # split the data into training and test sets
        test_list = posting_list[train_size:]
        test_label = class_vec[train_size:]
        posting_list = posting_list[:train_size]
        class_vec = class_vec[:train_size]

        train_set = []
        for posting_document in posting_list:
            train_set.append(self.is_word_in_vocab(vocab_list, posting_document))

        index = 0
        error_count = 0
        pAb, p1v, p0v = self.train_naive_bayes(np.array(train_set), np.array(class_vec))
        print(pAb, p1v, p0v)
        for example in test_list:
            test_input_data = np.array(self.is_word_in_vocab(vocab_list, example))
            print(example)
            test_result = self.classify(test_input_data, pAb, p1v, p0v)
            if test_result != test_label[index]:
                error_count += 1
            index += 1

        print("error rate: %f" % (error_count / len(test_label)))
     
DEBUG = True
nb = Bayes()
if DEBUG:
    nb.test()
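
Note that is_word_in_vocab builds a set-of-words (binary) vector: it records only whether a token appears, not how often. A bag-of-words variant, which counts occurrences, is a common refinement; a minimal sketch of such a helper (hypothetical, not part of the original code):

def bag_of_words_vector(vocab_list, input_set):
    # like is_word_in_vocab, but counts every occurrence of a token
    # instead of recording only presence/absence
    return_vector = [0] * len(vocab_list)
    for word in input_set:
        if word in vocab_list:
            return_vector[vocab_list.index(word)] += 1
    return return_vector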

References

  1. Machine Learning in Action (机器学习实战)
  2. https://github.com/apachecn/AiLearning/blob/master/docs/ml/4.%E6%9C%B4%E7%B4%A0%E8%B4%9D%E5%8F%B6%E6%96%AF.md
  3. https://www.cnblogs.com/jpcflyer/p/11069659.html
