版权声明：添加我的微信 wlagooble，开启一段不一样的旅程。原文链接：https://blog.csdn.net/nineship/article/details/88354310
code:
#coding:utf-8
from collections import Counter
import random

from matplotlib import pyplot as plt
import numpy as np
# Scatter colour for each predicted label; any other label uses _DEFAULT_COLOR.
_LABEL_COLORS = {1: '#00FF00', 2: '#00FFCC', 3: '#FF00CC'}
_DEFAULT_COLOR = '#6622CC'


def _plot_predictions(X_test, labels):
    """Scatter the first two features of each test point, coloured by its label."""
    colors = [_LABEL_COLORS.get(label, _DEFAULT_COLOR) for label in labels]
    if colors:  # X_test[:, 0] would fail on an empty (1-D) test array
        # One scatter call with a per-point colour list instead of one call per point.
        plt.scatter(X_test[:, 0], X_test[:, 1], c=colors)


def KNN(X_train, y_train, X_test, k, show=True):
    """Classify each test point by a majority vote of its k nearest training points.

    Distance is Manhattan (L1). The original implementation took the square
    root of the L1 sum. That is strictly increasing, so neighbour ranking is
    the same and dropping it leaves predictions unchanged.

    Vote ties go to the tied label that appears first among the neighbours
    sorted by distance. Equal distances are ordered by training-set position
    because the sort is stable.

    Args:
        X_train: training feature vectors, shape (n_train, n_features).
        y_train: n_train hashable labels, aligned with X_train.
        X_test: test feature vectors, shape (n_test, n_features).
        k: number of voting neighbours; must satisfy 1 <= k <= n_train.
        show: if True (the default, matching the original behaviour), print
            the predictions and display a scatter plot of the first two test
            features coloured by predicted label.

    Returns:
        A list of predicted labels (native Python values), one per row of
        X_test, in the same order.

    Raises:
        ValueError: if len(y_train) != len(X_train), or k is outside
            [1, len(X_train)].
    """
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test)
    n_train = len(X_train)
    if len(y_train) != n_train:
        raise ValueError("X_train and y_train must have the same length")
    if not 1 <= k <= n_train:
        raise ValueError("k must be between 1 and len(X_train)")

    predictions = []
    for point in X_test:
        # L1 distance from this test point to every training point at once.
        distances = np.abs(X_train - point).sum(axis=1)
        nearest = np.argsort(distances, kind="stable")[:k]
        # most_common keeps first-encountered order for equal counts, so the
        # tie-break matches the original strict '>' scan over insertion order.
        votes = Counter(y_train[nearest].tolist())
        predictions.append(votes.most_common(1)[0][0])

    if show:
        _plot_predictions(X_test, predictions)
        print(predictions)
        plt.ylabel('Y axis')
        plt.xlabel('X axis')
        plt.show()
    return predictions
if __name__ == "__main__":
    # Training data: one cluster per quadrant (label 1 = I, 2 = II, 3 = III,
    # 4 = IV). The constant third feature never contributes to a distance.
    X_train = [
        [45, 45, 1],
        [40, 50, 1],
        [60, 70, 1],
        [60, 20, 1],
        [-45, 45, 1],
        [-50, 50, 1],
        [-70, 20, 1],
        [-30, 70, 1],
        [-45, -45, 1],
        [-50, -50, 1],
        [-30, -60, 1],
        [50, -20, 1],
        [10, -70, 1],
        [30, -50, 1],
    ]
    y_train = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4]

    # One fixed probe (identical to a label-4 training point), followed by
    # 500 random points with integer coordinates in [-100, 100].
    X_test = [[50, -20, 1]]
    X_test += [
        [random.randint(-100, 100), random.randint(-100, 100), 1]
        for _ in range(500)
    ]
    KNN(X_train, y_train, X_test, 1)
~
显示结果如下：