# Logistic Regression实现iris分类 (iris classification with logistic regression)

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression 
from sklearn.metrics import accuracy_score

# Load the bundled iris dataset. The returned object exposes the attributes
# 'data' (a 150 x 4 feature matrix), 'target' (the class index of each
# sample), 'feature_names', 'target_names', 'filename' (the file the data was
# loaded from) and 'DESCR' (a detailed text description; print(iris_1.DESCR)
# to read it).
iris_1 = load_iris()


# Keep two of the four features, column 0 (sepal length) and column 2
# (petal length), so the fitted model can be drawn in 2-D below.
X = pd.DataFrame(iris_1['data'][:, [0, 2]],
                 columns=["Sepal length", "petal length"])
# Class index of every sample, as a Series named "target".
Y = pd.Series(iris_1['target'], name="target")

# C is the inverse regularization strength, so C=1e5 makes the L2 penalty
# negligible.
lr = LogisticRegression(C=1e5)
lr.fit(X, Y)  # fit() trains in place (it also returns lr itself)
y_pred = lr.predict(X)
# NOTE: accuracy is measured on the same samples the model was trained on
# (no held-out split), so it overstates how well the model generalizes.
print(accuracy_score(Y, y_pred))

# Visualize the fitted model: scatter the training points and overlay the
# decision boundaries between classes.
#
# predict() returns the class k with the largest linear score
#     s_k = c_k + a_k*x1 + b_k*x2
# where c_k = lr.intercept_[k] and (a_k, b_k) = lr.coef_[k]. The boundary
# between classes i and j is therefore where s_i == s_j, i.e.
#     (c_i - c_j) + (a_i - a_j)*x1 + (b_i - b_j)*x2 = 0
# A single class's own line s_k = 0 is not a boundary of this argmax rule,
# so drawing one line per class does not show where the classes split.
class_names = ('setosa', 'versicolor', 'virginica')
# Pairs of classes whose boundary is drawn. Versicolor is expected to sit
# between setosa and virginica in this projection, so only the two adjacent
# boundaries are drawn by default; add (0, 2) to also show that line.
boundary_pairs = ((0, 1), (1, 2))
# (color, marker) used for each class, in class-index order.
class_styles = (('red', 'o'), ('blue', '*'), ('yellow', 'x'))

plt.figure(figsize=(15, 8), dpi=80)

sepal_length = X.loc[:, "Sepal length"]
petal_length = X.loc[:, "petal length"]
for k, (name, (color, marker)) in enumerate(zip(class_names, class_styles)):
    plt.scatter(sepal_length[Y == k], petal_length[Y == k],
                color=color, marker=marker, label=name)

# Freeze the limits chosen for the data so steep boundary lines do not
# rescale the axes away from the points.
x_lim = plt.xlim()
y_lim = plt.ylim()

for i, j in boundary_pairs:
    a = lr.coef_[i][0] - lr.coef_[j][0]
    b = lr.coef_[i][1] - lr.coef_[j][1]
    c = lr.intercept_[i] - lr.intercept_[j]
    label = f"{class_names[i]} / {class_names[j]} boundary"
    if b == 0:
        # Vertical boundary x1 = -c / a: x2 cannot be solved for.
        plt.axvline(-c / a, label=label)
    else:
        # A straight line only needs its two end points across the x range.
        plt.plot(x_lim, [-(c + a * x1) / b for x1 in x_lim], label=label)

plt.xlim(x_lim)
plt.ylim(y_lim)
plt.legend()
plt.xlabel("Sepal length")
plt.ylabel("petal length")
plt.show()

# 转载自 (reprinted from) blog.csdn.net/ziqingnian/article/details/108350997