版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/weixin_44474718/article/details/86681634
引自百度:邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。
k-NN的输出结果会随着K的变化而变化。
import numpy as np
import cv2
import matplotlib.pyplot as plt
plt.style.use('ggplot')
#生成训练数据集
np.random.seed(7) #设置随机数种子
def generate_data(num_samples,num_features=2): #创建数据
data_size=(num_samples,num_features) #num_samples行, num_features列
data=np.random.randint(0,100,size=data_size)
labels_size=(num_samples,1)
labels=np.random.randint(0,2,size=labels_size)
return data.astype(np.float32),labels
def plot_data(all_blue,all_red):
plt.scatter(all_blue[:,0],all_blue[:,1],c='b',marker='s',s=180) # 蓝色的所有x轴坐标和y轴坐标
plt.scatter(all_red[:, 0],all_red[:,1], c='r', marker='^', s=180)
plt.xlabel('x')
plt.ylabel('y')
train_data ,labels=generate_data(11)
blue = train_data[labels.ravel() == 0]
red = train_data[labels.ravel() == 1]
plot_data(blue,red)
#训练分类器
knn=cv2.ml.KNearest_create() #创建分类器
knn.train(train_data,cv2.ml.ROW_SAMPLE,labels) #knn的数组必须是N*2的数组 传入训练数据
#预测新数据的类别
newcomer, _ =generate_data(1) #下划线_ 表示python忽略该输出值
plt.plot(newcomer[0,0],newcomer[0,1],'go',markersize=14) #画出需要预测值的图标
plt.show()
ret,result,neighbor,dist=knn.findNearest(newcomer,3)
# print("预测标签:%s ,邻近标签:%s,距离:%s" %(result,neighbor,dist))
print("Predicted label:\t", result)
print("Neighbors' labels:\t", neighbor)
print("Distance to neighbors:\t", dist)