《机器学习与数据挖掘》实验六

实验题目:   用线性核与高斯核训练支持向量机                                             

实验目的:   掌握支持向量机的原理及应用                                       

实验环境(硬件和软件)   Anaconda/Jupyter notebook/Pycharm                               

实验内容:

使用Sklearn,在西瓜集3.0α上分别使用线性核和高斯核训练一个SVM,并比较其支持向量的差别。

要求:

一、经给定部分代码,补充完整的代码,需要补充代码的地方已经用红色字体标注,包括:

1#补充构建SVM模型及训练代码

2#补充预测代码

3#补充得到支持向量代码

二、将补充完整的代码提交,并提交实验结果;(也可以自己重写这部分的代码提交

data_file_watermelon_3a = "watermelon_3a.csv"
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import svm

df = pd.read_csv(data_file_watermelon_3a, header=None, )
df.columns = ['id', 'density', 'sugar_content', 'label']
df.set_index(['id'])
X = df[['density', 'sugar_content']].values
y = df['label'].values

from sklearn import svm


for fig_num, kernel in enumerate(('linear', 'rbf')):
    # 补充构建SVM模型及训练代码
    # initial
    svc = svm.SVC(C=1000, kernel=kernel)  # classifier 1 based on linear kernel
    # train
    svc.fit(X, y)
    # 给定新的样本X_test,预测其标签
    X_test = [[0.719, 0.103]]
    # 补充预测代码
    y_train_pred = svc.predict(X_test)
    print('The predction result:', y_train_pred)
    # get support vectors
    sv = svc.support_vectors_
    ##### draw decision zone
    plt.figure(fig_num)
    plt.clf()

    # plot point and mark out support vectors
    plt.scatter(X[:, 0], X[:, 1], edgecolors='k', c=y, cmap=plt.cm.Paired, zorder=10)
    plt.scatter(sv[:, 0], sv[:, 1], edgecolors='k', facecolors='none', s=80, linewidths=2, zorder=10)
    # plot the decision boundary and decision zone into a color plot
    x_min, x_max = X[:, 0].min() - 0.2, X[:, 0].max() + 0.2
    y_min, y_max = X[:, 1].min() - 0.2, X[:, 1].max() + 0.2
    XX, YY = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))
    Z = svc.decision_function(np.c_[XX.ravel(), YY.ravel()])
    Z = Z.reshape(XX.shape)
    plt.pcolormesh(XX, YY, Z > 0, cmap=plt.cm.Paired)
    plt.contour(XX, YY, Z, colors=['k', 'k', 'k'], linestyles=['--', '-', '--'], levels=[-.5, 0, .5])

    plt.title(kernel)
    plt.axis('tight')
    plt.show()

 实验截图

 

猜你喜欢

转载自blog.csdn.net/m0_64351669/article/details/128199793