持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第11天,点击查看活动详情
有些回归算法也可以当作分类算法,就比如接下来要讲的逻辑回归(Logistic Regression)。别看名字有个回归,但其实常用于二元分类中,所以是一个分类算法。
逻辑回归
逻辑回归在分类时,不是直接返回莫一个类,而更多是返回是某一个类的概率。那么这个估计概率是怎么计算的呢?
在线性回归中,我们的预测值 ,而逻辑函数将原本的预测值作为参数,代入新公式中改,最后返回一个估计概率。公式如下:
这个 ,叫sigmoid函数,他会输出0-1之间的数值,他的公式和图像如下所示:
图1 sigmoid函数
最后预测的结果就很容易的,如果 的概率大于0.5,则最后的预测为正类(返回为1),否则预测为负类(返回为0)。预测值返回如下:
成本函数
既然我们有了估算概率,可以做出预测。但是在训练时的成本函数该怎么设置呢?和之前的约有不同,但是不难而且很有意思。
这里的我们需要分开讨论成本函数:
-
标签为1(正类)的条件下
-又有两种如下的情况:- 0.5,也就是预测值 =1 时, 成本函数是-log(1)=0,说明预测正确
- 0.5,也就是预测值 =0 时, 成本函数是-log(0)= ,说明预测错误
-
标签为0(负类)的条件下
-也有两种情况,读者可以根据上面的自行推导。
这样的函数设置很是巧妙,不管在正类还是负类下都能直接判断预测对错。当然我们一般还有另一个形式表达这个成本函数:
这里使用 表示成本函数,和上面的一样,我还是举预测一个正类的例子,如果预测为正类,则预测值 =1,1- 为0,成本是log ,如果预测为负类,则预测值 =0,1- 为1,成本是log 。
接下来让我们在实战中更好的去体会一下逻辑回归吧。
区分维吉尼亚鸢尾花
我们的数据集来自Scikit-Learn的datasets下的iris()数据集。里面有150朵三种不同的鸢尾花(山鸢尾、变色鸢尾和维吉尼亚鸢尾)。
from sklearn import datasets
# 基本流程,这里不多废话
iris = datasets.load_iris()
print(iris["DESCR"])
# 主要看一下对于该数据集的描述的Attribute Information这个部分
:Attribute Information:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
- class:
- Iris-Setosa
- Iris-Versicolour
- Iris-Virginica
# 上面巴拉巴拉一堆,让我们知道了,一共有四个属性,就是花萼和花瓣的长度和宽度
复制代码
咱们先来看一下什么是维吉尼亚鸢尾花,如下图2所示:
图2 维吉尼亚鸢尾花(图片来自网络)
如果有老铁看到图片标记存在问题,或者图片存在版权问题的还请告知,我会及时删除图片。
上面的标记我也是查看资料所示,不清楚标记的对不对,大致是上面标记出来的花瓣和花萼来区分鸢尾花到底属于哪一种。
from sklearn.linear_model import LogisticRegression
# 这里我们只需要判断花瓣的宽度和是否为维吉尼亚鸢尾花
X = iris["data"][:,3:]
y = (iris["target"]==2).astype(np.int)
# 训练部分
log_reg = LogisticRegression()
log_reg.fit(X, y)
# 预测部分
X_new = np.linspace(0, 3, 1000).reshape(-1, 1)
y_proba = log_reg.predict_proba(X_new)
decision_boundary = X_new[y_proba[:, 1] >= 0.5][0]
# 画图部分
plt.figure(figsize=(8, 3))
plt.plot([decision_boundary, decision_boundary], [-1, 2], "k:", linewidth=2,label="决策边界")
plt.plot(X_new, y_proba[:, 1], "g-", linewidth=2, label="维吉尼亚鸢尾")
plt.plot(X_new, y_proba[:, 0], "b--", linewidth=2, label="不是维吉尼亚鸢尾")
plt.axis([0, 3, -0.02, 1.02])
plt.xlabel("花瓣宽度")
plt.legend(loc="center left", fontsize=14)
plt.show()
复制代码
图3 维吉尼亚鸢尾花概率
这边就是简单的通过逻辑回归来训练一个分辨是否为维吉尼亚鸢尾花的二元分类器。其实我们可以看到如果要分辨出维吉尼亚鸢尾花,应该是有一个区间。但是这个区间内,不是维吉尼亚鸢尾花的概率也挺高,所以算法是不太好分辨的。于是会有设置出一个边界,只要超过这个边界直接判断属于哪一类。
我们可以看到中间就有一条决策边界,这个就是最终用于判断的,当花瓣宽度大于这个边界(大概是1.6cm)时,就判断为1,否则是0。