Binary classification of data with svm (full code)

I. Introduction

This is a small example of the author learning svm in opencv. The data set is generated by the function in the sk-learn library. The function is to classify the data into two categories.
In order to learn more deeply, the author compared the four commonly used cores of svm, linear, inter, sigmoid, and rbf.
See the code for detailed steps

Two, the code

import numpy as np
from cv2 import cv2
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn import model_selection as ms
from sklearn import metrics


def plot_decision_boundary(svm,x_text,y_text):
    '''
    函数功能,将svm训练出的结果可视化函数,参数说明,
    svm:训练好的svm模型
    x_text:测试数据集
    y_text:测试数据集标签
    '''
    #  确定x轴上的左右边界
    x_min,x_max = x_text[:,0].min()-1,x_text[:,0].max()+1
    #  确定y轴上的左右边界
    y_min,y_max = x_text[:,1].min()-1,x_text[:,1].max()+1
    #  创建合适的网格,h是步长
    h =0.02  
    xx,yy = np.meshgrid(np.arange(x_min,x_max,h),
                        np.arange(y_min,y_max,h))
    #  np.c_功能,按行连接两个矩阵,要求行数一样。
    x_hypo = np.c_[xx.ravel().astype(np.float32),
                   yy.ravel().astype(np.float32)]
    #  开始预测
    _,zz = svm.predict(x_hypo)
    #  大小要一样
    zz = zz.reshape(xx.shape)
    #  绘制三维等高线图,必须在网格结构中才可以
    plt.contourf(xx,yy,zz,cmap=plt.cm.coolwarm,alpha=0.8)
    #  按标签绘制散点图
    plt.scatter(x_text[:,0],x_text[:,1],c=y_text,s=100)




if __name__ == '__main__':
    #  使用自带函数创建一个含有100个样本,特征两个,两个标签的数据样本  
    x,y = datasets.make_classification(n_samples=100,n_features=2,n_classes=2,n_redundant=0,random_state=7816)
    x = x.astype(np.float32)
    y = y*2-1
#  定义一个含各种svm核的列表
    kernels = [cv2.ml.SVM_LINEAR,cv2.ml.SVM_INTER,
           cv2.ml.SVM_SIGMOID,cv2.ml.SVM_RBF]
    #  分割数据集,20%为测试集
    x_train,x_text,y_train,y_text = ms.train_test_split(x,y,test_size=0.2,random_state=42)
    #  训练含有不同核的svm分类器
    for i,kernel in enumerate(kernels):
        svm = cv2.ml.SVM_create()
        #  设置svm核
        svm.setKernel(kernel)
        svm.train(x_train,cv2.ml.ROW_SAMPLE,y_train)
        b,y_pred = svm.predict(x_text)
        #  用测试集计算准确率
        a=metrics.accuracy_score(y_text,y_pred)
        print(a)
        #  创建4个子图
        plt.subplot(2,2,i+1)
        #  可视化结果
        plot_decision_boundary(svm,x_text,y_text)
        plt.title('accuracy: %.2f' %a)
    plt.show()

3. Results

insert image description here
It can be seen from the figure that the rbf kernel has the highest accuracy rate, and the rbf kernel is more suitable for nonlinear classification

Guess you like

Origin blog.csdn.net/bookshu6/article/details/111748678