import pandas as pd
import numpy as np
def fit(df, lambda_=0):
    """Estimate naive Bayes parameters from a training DataFrame.

    The last column of ``df`` is the class label; every other column is a
    categorical feature.  All probabilities use the Bayesian estimate with
    additive constant ``lambda_`` (``lambda_ == 1`` is Laplace smoothing,
    ``lambda_ == 0`` the maximum likelihood estimate):

        P(Y=c_k)           = (N_k + lambda) / (N + K * lambda)
        P(X_j=a | Y=c_k)   = (N_{k,j,a} + lambda) / (N_k + S_j * lambda)

    where K is the number of classes and S_j is the number of distinct values
    feature j takes in the whole training set.

    Args:
        df: training data; last column is the class label.
        lambda_: non-negative smoothing constant.

    Returns:
        ``(pri_prob, con_prob)``: ``pri_prob`` maps class -> prior
        probability; ``con_prob`` maps class -> feature name -> feature value
        -> conditional probability.  Every feature value present in the
        training data has an entry for every class.
    """
    class_name = df.columns.to_list()[-1]
    feature_names = df.columns.to_list()[:-1]

    # Class counts N_k.  value_counts() drops rows whose label is missing, so
    # N is taken from the labelled rows only; otherwise priors would not sum to 1.
    class_count = df[class_name].value_counts().to_dict()
    n_samples = sum(class_count.values())
    n_classes = len(class_count)

    # Prior probabilities (Bayesian estimate).
    pri_prob = {
        label: (count + lambda_) / (n_samples + n_classes * lambda_)
        for label, count in class_count.items()
    }

    # S_j must be counted over the whole training set, not within one class,
    # and every observed value needs an entry for every class so that a value
    # never seen together with some class still gets lambda / (N_k + S_j*lambda).
    feature_values = {
        name: df[name].dropna().unique().tolist() for name in feature_names
    }

    # Conditional probabilities (Bayesian estimate).
    con_prob = {}
    for label, count in class_count.items():
        subset = df[df[class_name] == label]
        feature_prob = {}
        for name in feature_names:
            values = feature_values[name]
            value_count = subset[name].value_counts().to_dict()
            denominator = count + lambda_ * len(values)
            feature_prob[name] = {
                value: (value_count.get(value, 0) + lambda_) / denominator
                for value in values
            }
        con_prob[label] = feature_prob
    return pri_prob, con_prob
def predict(pri_prob, con_prob, dict):
    """Assign the most probable class to each record, store it under 'Y' and print it.

    Each record is scored as P(Y=c_k) * prod_j P(X_j=x_j | Y=c_k) using the
    tables returned by ``fit``.  The class with the highest score is written to
    ``record['Y']``; ties go to the class that comes first in ``pri_prob``.

    Args:
        pri_prob: class label -> prior probability.
        con_prob: class label -> feature name -> feature value -> conditional
            probability.
        dict: records to classify. Either a mapping of row key ->
            {feature name: value} (e.g. ``DataFrame.T.to_dict()``) or a list of
            such dicts.  The name shadows the builtin but is kept for callers
            that pass it by keyword.

    Returns:
        The same ``dict`` object, with every record updated in place.

    Raises:
        KeyError: if a record has a feature name or value that ``con_prob``
            has no entry for.
        ValueError: if ``pri_prob`` is empty.
    """
    records = dict
    # Iterate the real keys instead of assuming 0..n-1 (a DataFrame with a
    # non-default index produces other keys); also accept a plain list.
    rows = records.items() if hasattr(records, 'items') else enumerate(records)
    for i, record in rows:
        # 'Y' is the key this function writes; excluding it keeps a repeated
        # call (or input that already has a 'Y' column) from treating it as a feature.
        features = [(key, value) for key, value in record.items() if key != 'Y']
        class_prob = {}
        for label, prior in pri_prob.items():
            prob_list = [prior]  # prior followed by one conditional per feature
            for key, value in features:
                prob_list.append(con_prob[label][key][value])
            class_prob[label] = np.prod(prob_list)
        record['Y'] = max(class_prob, key=class_prob.get)  # class with the highest score
        print(i, record)
    return records
if __name__ == '__main__':
    # fit() treats the last column of the training sheet as the class label.
    train_df = pd.read_excel('test.xlsx')
    # lambda_=1 gives add-one (Laplace) smoothing in fit().
    priors, conditionals = fit(train_df, lambda_=1)
    # Transposing first makes to_dict() return {row index: {column: value}}.
    records = pd.read_excel('predict.xlsx').T.to_dict()
    predict(priors, conditionals, records)
运行结果如下：
0 {'X1': 2, 'X2': 'S', 'Y': -1}
1 {'X1': 3, 'X2': 'L', 'Y': 1}
代码的难点在于计算条件概率和求解目标函数（即选取使后验概率最大的类标记）。
学习朴素贝叶斯分类器，关键在于理解条件概率的意义、条件独立性假设以及贝叶斯估计的作用。
训练数据('test.xlsx')如下(可复制粘贴到xlsx):
X1 | X2 | Y |
1 | S | -1 |
1 | M | -1 |
1 | M | 1 |
1 | S | 1 |
1 | S | -1 |
2 | S | -1 |
2 | M | -1 |
2 | M | 1 |
2 | L | 1 |
2 | L | 1 |
3 | L | 1 |
3 | M | 1 |
3 | M | 1 |
3 | L | 1 |
3 | L | -1 |
测试数据('predict.xlsx')如下(可复制粘贴到xlsx):
X1 | X2 |
2 | S |
3 | L |