Getting k-nearest neighbor algorithm theory - Machine Learning

//2019.08.01 afternoon
machine learning algorithms 1 - k nearest neighbor
1, k-nearest neighbor learning machine learning algorithm is the most classic and simple algorithm , it is one of the best machine learning algorithms entry algorithm, it can be very good and fast understanding of the framework and application of machine learning algorithms.
2, kNN machine learning algorithm has the following characteristics:
(1) ideological extremely simple
(2) mathematical knowledge very few applications
(3) the effect of solving related problems is very good
(4) can be explained by a machine learning algorithm used in the process a lot of details question
(5) a more complete characterization of machine learning application process

The principle is as follows: determining a new attribute categories at all points on the basis of the original data set, the value of k is specified, then find k nearest point of all the raw input data points need to determine its new point, and then to determine the properties of the new points according to the k attribute classification points.

 

FIG original data points 1

 

2 position of the new input point of the distribution of FIG, 3 is designated k, i.e., find the closest three points

4、KNN算法原理介绍及其训练学习代码实现:
import numpy as np
import matplotlib.pyplot as plt #导入相应的数据可视化模块
raw_data_X=[[3.393533211,2.331273381],
[3.110073483,1.781539638],
[1.343808831,3.368360954],
[3.582294042,4.679179110],
[2.280362439,2.866990263],
[7.423436942,4.696522875],
[5.745051997,3.533989803],
[9.172168622,2.511101045],
[7.792783481,3.424088941],
[7.939820817,0.791637231]
]
raw_data_Y=[0,0,0,0,0,1,1,1,1,1]
print(raw_data_X)
print(raw_data_Y)
x_train=np.array(raw_data_X)
y_train=np.array(raw_data_Y)     #数据的预处理,需要将其先转换为矩阵,并且作为训练数据集
print(x_train)
print(y_train)
plt.figure(1)
plt.scatter(x_train[y_train==0,1],x_train[y_train==0,0],color="g")
plt.scatter(x_train[y_train==1,0],x_train[y_train==1,1],color="r") #将其散点图输出
x=np.array([8.093607318,3.365731514]) #定义一个新的点,需要判断它到底属于哪一类数据类型
plt.scatter(x[0],x[1],color="b") #在算点图上输出这个散点,看它在整体散点图的分布情况
#kNN机器算法的使用
from math import sqrt
distance=[]
for x_train in x_train:
d=sqrt(np.sum((x_train-x)**2))
distance.append(d)
print(distance)
d1=np.argsort(distance) #输出distance排序的索引值
print(d1)
k=6
n_k=[y_train[(d1[i])] for i in range(0,k)]
print(n_k)
from collections import Counter #导入Counter模块
c=Counter(n_k).most_common(1)[0][0] #Counter模块用来输出一个列表中元素的个数,输出的形式为列表,其里面的元素为不同的元组
#另外的话对于Counter模块它有.most_common(x)可以输出统计数字出现最多的前x个元组,其中元组的key是其元素值,后面的值是出现次数
y_predict=c
print(y_predict)
plt.show() #输出点的个数

实现代码及其结果如下:

Guess you like

Origin www.cnblogs.com/Yanjy-OnlyOne/p/11283454.html