版权声明:原创文章未经博主允许不得转载O(-_-)O!!! https://blog.csdn.net/u013894072/article/details/83584969
算法思想通俗易懂:将待预测的数据X与所有历史(训练)数据逐一计算距离,找出距离最小的前K个样本点,统计这K个样本中哪种类别出现次数最多,就判定X属于该类别。
直接上代码:这里利用了TensorFlow中的MNIST手写数字数据集
#!/usr/bin/python
# -*- coding:utf-8 -*-
"""
Author LiHao
Time 2018/10/31 10:46
"""
import os
import sys
import platform
import tensorflow as tf
def __getCurrentPathAndOS__():
    """
    Report the operating-system family and the directory holding this file.

    :return: tuple ``(os_name, current_path)``; ``os_name`` is ``"windows"``
        or ``"linux/mac"`` and ``current_path`` is the absolute directory of
        this source file.
    """
    # Resolve to an absolute path first: ``__file__`` may be a bare file name
    # (e.g. ``python knn.py``), for which ``dirname`` alone would return "".
    current_path = os.path.dirname(os.path.abspath(__file__))
    # Match on the prefix only: macOS reports "Darwin", which *contains*
    # "win", so a substring test would misclassify it as Windows.
    if platform.system().lower().startswith("win"):
        return "windows", current_path
    return "linux/mac", current_path
def load_mnist():
    """
    Load the MNIST data set with one-hot encoded labels.

    The data directory is ``MNIST_DATA`` inside the directory reported by
    ``__getCurrentPathAndOS__``.

    :return: the data-set object produced by ``input_data.read_data_sets``.
    """
    MNIST_PATH = "MNIST_DATA"
    # os.path.join chooses the correct separator itself, so the OS name
    # returned alongside the path is not needed here.
    _, current_path = __getCurrentPathAndOS__()
    data_dir = os.path.join(current_path, MNIST_PATH)
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets(data_dir, one_hot=True)
    return mnist
#!/usr/bin/python
# -*- coding:utf-8 -*-
"""
Author LiHao
Time 2018/10/16 9:47
"""
import os
import sys
from collections import Counter

import numpy as np
class Algorithm(object):
    """
    Base class for algorithms: outlines the steps of a training workflow.

    Every hook below is a no-op that returns ``None``; subclasses override
    the ones they need.
    """

    def inputs(self):
        """
        Receive the input data.

        :return: None
        """
        return None

    def train(self):
        """
        Run training.

        :return: None
        """
        return None

    def predict(self):
        """
        Produce predictions.

        :return: None
        """
        return None
class KNN(Algorithm):
    """
    K-nearest-neighbours classifier.

    Training only stores the samples and their one-hot labels. Prediction
    measures the Euclidean distance from each query to every stored sample
    and takes a majority vote among the K closest ones.
    """

    def __init__(self):
        self._X = np.array([])
        self._Y = np.array([])

    def _distance(self, A, B):
        """
        Compute the Euclidean distance from one sample to every row of B.

        :param A: a single sample holding ``B.shape[1]`` values.
        :param B: 2-D array with one stored sample per row.
        :return: 1-D array containing one distance per row of ``B``.
        """
        # The explicit reshape also validates that A has B.shape[1] values.
        diff = B - A.reshape((1, B.shape[1]))
        return np.sqrt(np.sum(np.square(diff), axis=1))

    def _sort_index(self, K, l=()):
        """
        Return the indices of the K smallest values.

        :param K: how many indices to keep.
        :param l: sequence of distances (immutable default avoids sharing a
            mutable object between calls).
        :return: array of at most K indices, ordered by ascending value.
        """
        return np.argsort(l)[:K]  # argsort yields indices in sorted order

    def transfor_to_label(self, y):
        """
        Convert a one-hot vector into its integer class index.

        :param y: 1-D array in which exactly one entry equals 1.0.
        :return: position of that entry as an ``int``.
        """
        return int(np.squeeze(np.where(y == 1.0)[0]))

    def predict(self, XX, K=5, y_len=10):
        """
        Predict one-hot labels for the given samples.

        :param XX: 2-D array with one sample per row; a single 1-D sample is
            also accepted and treated as one row.
        :param K: number of nearest neighbours that vote; must be >= 1.
        :param y_len: length of the one-hot output; defaults to the ten MNIST
            classes, where the index is the digit.
        :return: float32 array of shape ``(len(XX), y_len)`` with a single
            1.0 per row at the predicted class index.
        :raises ValueError: if ``K < 1`` or no training data has been stored.
        """
        if K < 1:
            raise ValueError("K must be at least 1")
        XX = np.atleast_2d(np.asarray(XX))
        data_num = XX.shape[0]
        labels = np.zeros((data_num, y_len), dtype=np.float32)
        if data_num and (np.ndim(self._X) != 2 or len(self._X) == 0):
            raise ValueError("KNN has no training data; call train() first")
        for row, sample in enumerate(XX):
            neighbours = self._sort_index(K, self._distance(sample, self._X))
            votes = Counter(
                self.transfor_to_label(self._Y[idx]) for idx in neighbours
            )
            # On equal counts most_common keeps first-seen order, so the
            # label met first (i.e. belonging to the nearer neighbour) wins.
            winner = votes.most_common(1)[0][0]
            labels[row, winner] = 1.0
        return labels

    def inputs(self, X, Y):
        """
        Store the training samples and labels.

        :param X: training samples, one per row.
        :param Y: one-hot labels aligned with the rows of ``X``.
        """
        # Coerce to arrays so plain nested lists work as well.
        self._X = np.asarray(X)
        self._Y = np.asarray(Y)

    def train(self, X, Y):
        """
        "Train" the model, which for KNN just means memorising the data.

        :param X: training samples, one per row.
        :param Y: one-hot labels aligned with the rows of ``X``.
        """
        self.inputs(X, Y)
def knn_train_mnist(step=100):
    """
    Classify the MNIST test set with KNN and print the error rates.

    The test set is processed in batches; the error rate of every batch is
    printed, followed by the error rate over the whole test set.

    :param step: number of test samples per batch (must be >= 1). The
        original code used this name without ever defining it, so it is
        exposed here as a parameter with a default.
    :return: None
    :raises ValueError: if ``step < 1``.
    """
    if step < 1:
        raise ValueError("step must be at least 1")
    knn = KNN()
    mnist = load_mnist()
    train_mnist_x = mnist.train.images
    train_mnist_y = mnist.train.labels
    test_mnist_x = mnist.test.images
    test_mnist_y = mnist.test.labels
    END = test_mnist_y.shape[0]
    knn.train(train_mnist_x, train_mnist_y)
    all_wrong_count = 0
    for batch_no, start in enumerate(range(0, END, step)):
        end = start + step
        batch_y = test_mnist_y[start:end]
        predict_labels = knn.predict(test_mnist_x[start:end])
        predict_indexes = np.argmax(predict_labels, axis=1)
        real_indexes = np.argmax(batch_y, axis=1)
        # Any mismatch is an error; testing "difference > 0" would miss the
        # cases where the predicted digit is smaller than the real one.
        wrong_count = int(np.count_nonzero(predict_indexes != real_indexes))
        all_wrong_count += wrong_count
        # Divide by the real batch size: the last batch may be short.
        wrong_rate = wrong_count * 1.0 / len(batch_y)
        print(batch_no, " time\twrong rate is: ", wrong_rate)
    print("-*- All wrong rate is:", all_wrong_count * 1.0 / END)
# Run the MNIST evaluation only when this file is executed as a script.
if "__main__" == __name__:
    knn_train_mnist()
测试后的结果如下:
优点:简单易用
缺点:算法运行时间长(每次预测都要与全部训练样本计算距离),且结果不稳定:K的取值不同,预测结果可能随之变化。K不宜取得过大。