支持向量机SVM理解篇

Posted
by 子颢 on August 14,2018
#算法原理
支持向量机(Support Vector Machine,SVM)是机器学习中的最经典也是最重要的分类方法之一。



样本空间中任一点x到超平面的距离为:


现在我们希望求解上式来得到最大间隔超平面所对应的模型:f(x) = w * x + b




下面还是通过一个具体例子感受一下线性可分支持向量机的训练过程。


核函数(kernel trick),我们在线性回归等几个小节中曾经提到过核函数的概念,polynomial也是核函数的一种。






那么我们在实际应用当中到底应该怎样选择核函数呢?告诉大家一条铁律:首先选择线性核(LinearSVC),如果训练集不太大,再试一下RBF核。 只有一个对称函数所对应的核矩阵半正定,它才能作为核函数使用(亦即才能拆成映射函数的内积)。


为了解决这个问题,可以对每个样本点引进一个对应的松弛变量,用以表征样本不满足约束的程度,使函数间隔加上松弛变量大于等于1。这样,约束条件变为:


我们从损失函数的角度看,gamma表示样本不满足约束的程度,如果样本满足约束,那么gamma值为0。所以这实际上是hinge损失:hinge(z) = max(0, 1-z)。上面7.31式可以改写为:

加号后的一项就是SVM的hinge损失函数,加号前的一项恰好是L2正则。如果我们将上式中的损失函数变为对数损失,那么恰好变成了加了L2正则的逻辑回归。

既然讲到了这里,那我们不防继续深入一下,试着从损失函数的角度探讨SVM和LR各自的特点是什么?
1.    因为LR和SVM的优化目标接近(损失函数渐进趋同),所以通常情况下他们的表现也相当。
2.    SVM的hinge损失函数在z大于1后,都是平坦的0区域,这使得SVM的解具有稀疏性(只与支持向量有关,函数图像拐点位置);而LR的log损失是光滑的单调递减函数,不能导出类似支持向量的概念。因此LR的求解过程依赖于所有样本点,开销更大(尤其是需要用到核函数时)。
3.    SVM和LR都是使用一个最优分隔超平面进行样本点的划分,且距离分隔超平面越远的点对模型的训练影响越小。SVM是完全无影响(平坦的0区域),LR是影响较弱(损失函数渐进趋于0)。
4.    因为SVM的训练只与支持向量点有关,所以数据unbalance对SVM几乎无影响,而LR一般需要做样本均衡处理。
5.    LR回归的输出具有自然的概率含义,SVM的输出是样本点到最优超平面的距离,欲得到概率需要进行特殊处理。
我们依然通过拉格朗日乘子法求解加入松弛变量的SVM:



支持向量回归(SVR):找到两条平行直线带,带内点的损失为0,带上的点是尽可能多的支持向量。


训练方法与SVM相同。
#模型训练
代码地址 https://github.com/qianshuang/ml-exp
def train():
   print("start training...")
   # 处理训练数据
   # train_feature, train_target = process_file(train_dir, word_to_id, cat_to_id)  # 词频特征
   train_feature, train_target = process_tfidf_file(train_dir, word_to_id, cat_to_id)  # TF-IDF特征
   # 模型训练
   model.fit(train_feature, train_target)
def test():
   print("start testing...")
   # 处理测试数据
   test_feature, test_target = process_file(test_dir, word_to_id, cat_to_id)
   # test_predict = model.predict(test_feature)  # 返回预测类别
   test_predict_proba = model.predict_proba(test_feature)    # 返回属于各个类别的概率
   test_predict = np.argmax(test_predict_proba, 1)  # 返回概率最大的类别标签
   # accuracy
   true_false = (test_predict == test_target)
   accuracy = np.count_nonzero(true_false) / float(len(test_target))
   print()
   print("accuracy is %f" % accuracy)
   # precision    recall  f1-score
   print()
   print(metrics.classification_report(test_target, test_predict, target_names=categories))
   # 混淆矩阵
   print("Confusion Matrix...")
   print(metrics.confusion_matrix(test_target, test_predict))
if not os.path.exists(vocab_dir):
   # 构建词典表
   build_vocab(train_dir, vocab_dir)
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_dir)
# kNN
# model = neighbors.KNeighborsClassifier()
# decision tree
# model = tree.DecisionTreeClassifier()
# random forest
# model = ensemble.RandomForestClassifier(n_estimators=10)  # n_estimators为基决策树的数量,一般越大效果越好直至趋于收敛
# AdaBoost
# model = ensemble.AdaBoostClassifier(learning_rate=1.0)  # learning_rate的作用是收缩基学习器的权重贡献值
# GBDT
# model = ensemble.GradientBoostingClassifier(n_estimators=10)
# xgboost
# model = xgboost.XGBClassifier(n_estimators=10)
# Naive Bayes
model = naive_bayes.MultinomialNB()
# logistic regression
# model = linear_model.LogisticRegression()   # ovr
# model = linear_model.LogisticRegression(multi_class="multinomial", solver="lbfgs")  # softmax回归
# SVM
model = svm.LinearSVC()  # 线性,无概率结果
model = svm.SVC(probability=True)  # 核函数,训练慢
train()
test()
运行结果:
read_category...
read_vocab...
start training...
start testing...
accuracy is 0.970000
            precision    recall  f1-score   support
        游戏       1.00      1.00      1.00       104
        时政       0.92      0.93      0.92        94
        体育       1.00      0.99      1.00       116
        娱乐       0.99      0.99      0.99        89
        时尚       1.00      0.99      0.99        91
        教育       0.97      0.94      0.96       104
        家居       0.91      0.96      0.93        89
        财经       0.96      0.96      0.96       115
        科技       1.00      0.99      0.99        94
        房产       0.94      0.96      0.95       104
avg / total       0.97      0.97      0.97      1000
Confusion Matrix...
[[104   0   0   0   0   0   0   0   0   0]
[  0  87   0   0   0   0   1   3   0   3]
[  0   1 115   0   0   0   0   0   0   0]
[  0   1   0  88   0   0   0   0   0   0]
[  0   0   0   0  90   1   0   0   0   0]
[  0   1   0   1   0  98   3   0   0   1]
[  0   1   0   0   0   2  85   1   0   0]
[  0   1   0   0   0   0   2 110   0   2]
[  0   0   0   0   0   0   1   0  93   0]
[  0   3   0   0   0   0   1   0   0 100]]

社群
QQ交流群

微信公众号
了解更多干货文章,可以关注小程序八斗问答

猜你喜欢

转载自blog.csdn.net/gaoyan0335/article/details/86139139