Python实现支持向量机(SVM) MNIST数据集

Python实现支持向量机(SVM) MNIST数据集

SVM的原理这里不讲,大家自己可以查阅相关资料。

下面是利用sklearn库进行svm训练MNIST数据集,准确率可以达到90%以上。


from sklearn import svm
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
train_num = 10000
test_num = 1000

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

# 获取一个支持向量机模型
predictor = svm.SVC(gamma='scale', C=1.0, decision_function_shape='ovr', kernel='rbf')
# 把数据丢进去
predictor.fit(x_train[:train_num], y_train[:train_num])
# 预测结果
result = predictor.predict(x_test[:test_num])
# 准确率估计
accurancy = np.sum(np.equal(result, y_test[:test_num])) / test_num
print(accurancy)

SVC函数的参数解析

gamma

支持向量机的间隔,即是超平面距离不同类别的最小距离,是一个float类型的值,可以自己规定,也可以用SVM自己的值,有两个选择。

  • auto 选择auto时,gamma = 1/feature_num ,也就是特征的数目分之1
  • scale 选择scale时,gamma = 1/(feature_num * X.std()), 特征数目乘样本标准差分之1. 一般来说,scael比auto结果准确。
C

看到过SVM公式推导的同学对C一定不陌生,它是松弛变量的系数,称为惩罚系数,用来调整容忍松弛度,当C越大,说明该模型对分类错误更加容忍,也就是为了避免过拟合。

decision_function_shape

两个选择

  • ovr one vs rest 将一个类别与其他所有类别进行划分
  • ovo one vs one 两两划分

kernel

核函数的选择

  • 当样本线性可分时,一般选择linear 线性核函数
  • 当样本线性不可分时,有很多选择,这里选择rbf 即径向基函数,又称高斯核函数。

猜你喜欢

转载自blog.csdn.net/qq_35170267/article/details/84290367
今日推荐