kNN classification of the fruit dataset

Dataset download: https://download.csdn.net/download/fanzonghao/10940440

kNN algorithm workflow:

For a test sample, compute its distance to every training sample, take the k nearest neighbors, and assign the sample to the majority class among those neighbors. If k were made arbitrarily large (approaching the size of the training set), the prediction would no longer depend on the test sample at all but only on the class proportions: every sample would be assigned to the class with the largest share.
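As a rough sketch of that majority-vote flow (not part of the original code, and independent of scikit-learn), the function below computes Euclidean distances, keeps the k nearest training samples, and takes a majority vote; the name knn_predict and the toy data are purely illustrative:

import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x_test, k=5):
    # Euclidean distance from the test point to every training sample
    dists = np.linalg.norm(X_train - x_test, axis=1)
    # indices of the k closest training samples
    nearest = np.argsort(dists)[:k]
    # majority vote among the k neighbors
    return Counter(y_train[nearest]).most_common(1)[0][0]

# toy usage: two 2-D classes, one test point near class 1
X = np.array([[1, 1], [1, 2], [8, 8], [9, 8]])
y = np.array([0, 0, 1, 1])
print(knn_predict(X, y, np.array([8, 9]), k=3))  # -> 1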

First, look at the dataset: the four feature columns mass, height, width and color_score are used to classify the fruits.

g=sns.pairplot(data=fruits_df,hue='fruit_name',vars=['mass','width','height','color_score'])

sns.pairplot is then used to inspect the pairwise relationships between features; the diagonal shows a histogram for each class, and mass and width are almost linearly related.

Next, build a 3-D scatter plot from width, height and color_score; the green class is easy to separate. For higher-dimensional data, PCA can be used to reduce the dimensionality before visualizing.
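As a hedged sketch of the PCA idea mentioned above (this is not in the original code; the variable names are illustrative), the four features can be standardized and projected onto two principal components before plotting:

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

fruits_df = pd.read_table('fruit_data_with_colors.txt')
X = fruits_df[['mass', 'width', 'height', 'color_score']]
y = fruits_df['fruit_label']

# standardize first so mass (much larger scale) does not dominate, then project to 2-D
X_2d = PCA(n_components=2).fit_transform(StandardScaler().fit_transform(X))
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap='viridis', s=60)
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.show()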

kNN code:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import seaborn as sns
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import ml_visualization
# classify the fruits with k-nearest neighbors
# split the data into training and test sets
from sklearn.model_selection import train_test_split
fruits_df=pd.read_table('fruit_data_with_colors.txt')
print(fruits_df)
print('number of samples:', len(fruits_df))
# build a dict mapping fruit_label to fruit_name
fruits_name_dict=dict(zip(fruits_df['fruit_label'],fruits_df['fruit_name']))
#print(fruits_df['fruit_label'])
print(fruits_name_dict)
# split the dataset into features X and labels y
X=fruits_df[['mass','width','height','color_score']]
# print(X)
y=fruits_df['fruit_label']
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=1/4,random_state=0)
print('X_train=\n',X_train)
print('y_train=\n',y_train)
print('total samples: {}, training samples: {}, test samples: {}'.format(len(X),len(X_train),len(X_test)))
# visualize the features: the diagonal shows per-class histograms, the other panels show pairwise relationships
g=sns.pairplot(data=fruits_df,hue='fruit_name',vars=['mass','width','height','color_score'])
plt.savefig('1.jpg')
# 3-D view of width, height and color_score
label_color_dict={1:'red',2:'green',3:'blue',4:'yellow'}
colors=list(map(lambda label: label_color_dict[label],y_train))
print('colors=\n',colors)
fig=plt.figure()
ax=fig.add_subplot(111,projection='3d')
ax.scatter(X_train['width'],X_train['height'],X_train['color_score'],c=colors,marker='o',s=100)
ax.set_xlabel('width')
ax.set_ylabel('height')
ax.set_zlabel('color_score')
plt.show()
# build kNN models for different k and record the test accuracy
acc_scores=[]
k_range=range(1,20)
for k in k_range:
    knn=KNeighborsClassifier(n_neighbors=k)
    # train the model
    knn.fit(X_train,y_train)
    # predict on the test set
    y_pred=knn.predict(X_test)
    acc=accuracy_score(y_test,y_pred)
    acc_scores.append(acc)
plt.figure()
plt.xlabel('k')
plt.ylabel('accuracy')
plt.plot(k_range,acc_scores,marker='o')
plt.show()
best_k=k_range[int(np.argmax(acc_scores))]
print('best k: {}, accuracy: {:.3f}'.format(best_k,max(acc_scores)))
ml_visualization.plot_fruit_knn(X_train,y_train,5)
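The loop above scores every k on a single train/test split. As an alternative sketch (not in the original post), cross_val_score can average accuracy over several folds of the training set, which usually gives a more stable choice of k:

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# assumes X_train and y_train from the script above
cv_scores = []
for k in range(1, 20):
    knn = KNeighborsClassifier(n_neighbors=k)
    # mean accuracy over 5 folds of the training data
    cv_scores.append(cross_val_score(knn, X_train, y_train, cv=5).mean())
best_k = int(np.argmax(cv_scores)) + 1
print('best k by 5-fold cross-validation:', best_k)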

Visualization code: ml_visualization.py

# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors
import matplotlib.patches as mpatches

def plot_fruit_knn(X, y, n_neighbors):
    """
        在“水果数据集”上对 height 和 width 二维数据进行kNN训练
        并绘制出结果
    """
    X_mat = X[['height', 'width']].as_matrix()
    y_mat = y.as_matrix()

    # Create color maps
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF', '#AFAFAF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#AFAFAF'])

    clf = neighbors.KNeighborsClassifier(n_neighbors)
    clf.fit(X_mat, y_mat)

    # Plot the decision boundary by assigning a color in the color map
    # to each mesh point.
    
    mesh_step_size = .01  # step size in the mesh
    plot_symbol_size = 50
    
    x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1
    y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                         np.arange(y_min, y_max, mesh_step_size))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

    # Plot training points
    plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y, cmap=cmap_bold,
                edgecolor='black')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())

    patch0 = mpatches.Patch(color='#FF0000', label='apple')
    patch1 = mpatches.Patch(color='#00FF00', label='mandarin')
    patch2 = mpatches.Patch(color='#0000FF', label='orange')
    patch3 = mpatches.Patch(color='#AFAFAF', label='lemon')
    plt.legend(handles=[patch0, patch1, patch2, patch3])

    plt.xlabel('height (cm)')
    plt.ylabel('width (cm)')
    
    plt.show()

Result: the accuracy is highest when k is 5.

For regression, with k = 3 the prediction is the average of the three nearest neighbors' target values; the neighbors can also be weighted by distance.
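A minimal sketch of kNN regression with scikit-learn's KNeighborsRegressor (toy 1-D data, not the fruit dataset), comparing uniform averaging with distance weighting:

import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# toy 1-D regression data
X = np.array([[1], [2], [3], [4], [5]])
y = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# uniform weights: plain average of the 3 nearest targets
reg_uniform = KNeighborsRegressor(n_neighbors=3, weights='uniform').fit(X, y)
# distance weights: closer neighbors contribute more
reg_distance = KNeighborsRegressor(n_neighbors=3, weights='distance').fit(X, y)

print(reg_uniform.predict([[1.5]]))   # (1 + 2 + 3) / 3 = 2.0
print(reg_distance.predict([[1.5]]))  # ~1.71, pulled toward the nearer neighbors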


Reposted from blog.csdn.net/fanzonghao/article/details/86411102