import matplotlib.pyplot as plt import numpy as np from sklearn import svm from sklearn.metrics import classification_report data=np.genfromtxt("LR-testSet2.txt",delimiter=",") #print(data) x_data=data[:,:-1] y_data=data[:,-1] #print(y_data) def plot(): x0=[] x1=[] y0=[] y1=[] #切分为不同的数据 for i in range(len(x_data)): if y_data[i]==0: x0.append(x_data[i,0]) y0.append(x_data[i,1]) else: x1.append(x_data[i, 0]) y1.append(x_data[i, 1]) #画图 #print(x0) scatter0=plt.scatter(x0,y0,c='b',marker='o') scatter1=plt.scatter(x1,y1,c='r',marker='x') #画图例 plt.legend(handles=[scatter0,scatter1],labels=['label0','label1'],loc='best') plot() plt.show() model=svm.SVC(kernel='rbf',C=2,gamma=1)#设置核函数,不同的核函数。有不同精度。C和gamma可以自己设置 model.fit(x_data,y_data) print(model.score(x_data,y_data)) #获取数据所在范围 x_min,x_max=x_data[:,0].min()-1,x_data[:,0].max()+1 y_min,y_max=x_data[:,1].min()-1,x_data[:,1].max()+1 #生成网格矩阵 xx,yy=np.meshgrid(np.arange(x_min,x_max,0.02), np.arange(y_min,y_max,0.02)) z=model.predict(np.c_[xx.ravel(),yy.ravel()]) z=z.reshape(xx.shape) #等高线图 cs=plt.contourf(xx,yy,z) plot plt.show()
svm处理非线性分类的方式。代码如下
猜你喜欢
转载自blog.csdn.net/zhuiyunzhugang/article/details/105880580
今日推荐
周排行