svm处理非线性分类的方式。代码如下

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm
from sklearn.metrics import classification_report
data=np.genfromtxt("LR-testSet2.txt",delimiter=",")
#print(data)
x_data=data[:,:-1]

y_data=data[:,-1]
#print(y_data)
def plot():
    x0=[]
    x1=[]
    y0=[]
    y1=[]
    #切分为不同的数据
    for i in range(len(x_data)):
        if y_data[i]==0:
            x0.append(x_data[i,0])
            y0.append(x_data[i,1])
        else:
            x1.append(x_data[i, 0])
            y1.append(x_data[i, 1])
    #画图
    #print(x0)
    scatter0=plt.scatter(x0,y0,c='b',marker='o')
    scatter1=plt.scatter(x1,y1,c='r',marker='x')
    #画图例
    plt.legend(handles=[scatter0,scatter1],labels=['label0','label1'],loc='best')
plot()
plt.show()
model=svm.SVC(kernel='rbf',C=2,gamma=1)#设置核函数,不同的核函数。有不同精度。C和gamma可以自己设置
model.fit(x_data,y_data)
print(model.score(x_data,y_data))
#获取数据所在范围
x_min,x_max=x_data[:,0].min()-1,x_data[:,0].max()+1
y_min,y_max=x_data[:,1].min()-1,x_data[:,1].max()+1
#生成网格矩阵
xx,yy=np.meshgrid(np.arange(x_min,x_max,0.02),
                  np.arange(y_min,y_max,0.02))
z=model.predict(np.c_[xx.ravel(),yy.ravel()])
z=z.reshape(xx.shape)
#等高线图
cs=plt.contourf(xx,yy,z)
plot
plt.show()

猜你喜欢

转载自blog.csdn.net/zhuiyunzhugang/article/details/105880580