Machine Learning -*- KNN Nearest-Neighbor Classification


The idea behind the algorithm is easy to grasp: compute the distance between the sample X to be predicted and every sample in the historical (training) data, take the K samples with the smallest distances, and see which class appears most often among them; X is assigned to that class.
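To make the idea concrete before the MNIST code, here is a minimal stand-alone sketch of those three steps (distance, top-K, majority vote) on a toy 2-D dataset; the points, labels, and the helper name knn_predict are made up for illustration:

import numpy as np
from collections import Counter

# toy training data: six 2-D points from two classes (values are made up)
train_X = np.array([[1.0, 1.1], [1.2, 0.9], [0.8, 1.0],
                    [4.0, 4.2], [4.1, 3.9], [3.8, 4.0]])
train_y = np.array([0, 0, 0, 1, 1, 1])

def knn_predict(x, K=3):
    # Euclidean distance from x to every training point
    dists = np.sqrt(np.sum((train_X - x) ** 2, axis=1))
    # indices of the K closest points
    nearest = np.argsort(dists)[:K]
    # majority vote among their labels
    return Counter(train_y[nearest]).most_common(1)[0][0]

print(knn_predict(np.array([1.0, 1.0])))  # -> 0
print(knn_predict(np.array([4.0, 4.0])))  # -> 1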
Straight to the code. The MNIST handwritten-digit dataset shipped with TensorFlow's tutorial utilities is used; the first snippet loads the data and the second implements and evaluates KNN.

#!/usr/bin/python
# -*- coding:utf-8 -*-
"""
Author LiHao
Time 2018/10/31 10:46
"""
import os
import sys
import platform
import tensorflow as tf

def __getCurrentPathAndOS__():
    """
    Get the directory containing this file and the name of the operating system.
    :return:
    """
    filename = __file__
    current_path = os.path.dirname(filename)
    os_name = platform.system()
    if os_name.lower().startswith("win"):  # "Darwin".lower() also contains "win", so check the prefix instead
        return "windows",current_path
    else:
        return "linux/mac",current_path

def load_mnist():
    MNIST_PATH = "MNIST_DATA"
    os_name,current_path = __getCurrentPathAndOS__()
    if os_name == "windows":  # "is" compares identity, not string equality
        current_path += '\\' + MNIST_PATH
    else:
        current_path += '/' +MNIST_PATH
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets(current_path, one_hot=True)
    return mnist
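Note that the tensorflow.examples.tutorials.mnist module used above was removed in TensorFlow 2.x. If it is not available, an equivalent set of arrays (flattened images, one-hot labels) can be built from tf.keras.datasets.mnist instead; this is only a sketch, and the function name load_mnist_keras is not from the original code:

import numpy as np
import tensorflow as tf

def load_mnist_keras():
    # Returns (train_x, train_y, test_x, test_y) in the shapes the KNN code below expects:
    # images flattened to 784 floats in [0, 1], labels one-hot encoded over 10 classes.
    (train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()
    train_x = train_x.reshape(-1, 784).astype(np.float32) / 255.0
    test_x = test_x.reshape(-1, 784).astype(np.float32) / 255.0
    train_y = np.eye(10, dtype=np.float32)[train_y]
    test_y = np.eye(10, dtype=np.float32)[test_y]
    return train_x, train_y, test_x, test_y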

#!/usr/bin/python
# -*- coding:utf-8 -*-
"""
Author LiHao
Time 2018/10/16 9:47
"""
import os
import sys
import numpy as np

class Algorithm(object):
    """
    Algorithm base class -- defines the interface for input, training and prediction.
    """
    def predict(self):
        """
        Predict labels for new samples.
        :return:
        """
        pass
    def inputs(self):
        """
        Load the input (training) data.
        :return:
        """
        pass
    def train(self):
        """
        Train the model.
        :return:
        """
        pass

class KNN(Algorithm):
    """
    KNN classification algorithm
    """
    def __init__(self):
        self._X = np.array([])
        self._Y = np.array([])

    def _distance(self,A,B):
        """
        Compute the Euclidean distance between sample A and every row of matrix B.
        :param A: a single sample, shape (feature_num,)
        :param B: the training matrix, shape (n_samples, feature_num)
        :return: a vector of n_samples distances
        """
        return np.sqrt(np.sum(np.power(B-A.reshape((1,B.shape[1])),2),axis=1))

    def _sort_index(self,K,l):
        """
        Return the indices of the K nearest samples.
        :param K: number of neighbours to keep
        :param l: array of distances
        :return: indices of the K smallest distances
        """
        l_index = np.argsort(l)[:K]  # argsort returns the indices that would sort l ascending
        return l_index

    def transfor_to_label(self,y):
        """Convert a one-hot label vector to its integer class index."""
        return int(np.squeeze(np.where(y==1.0)[0]))


    def predict(self,XX,K=5,y_len=10):
        """
        :param XX: samples to classify, shape (n_samples, feature_num)
        :param K: number of nearest neighbours that vote
        :param y_len: length of the one-hot label vector (10 for MNIST); the index encodes the digit class
        :return: predicted one-hot labels, shape (n_samples, y_len)
        """
        data_num,feature_num = XX.shape
        labels = np.zeros((data_num,y_len),dtype=np.float32)
        for x in range(data_num):
            distance = self._distance(XX[x],self._X)
            dis_index = self._sort_index(K,distance)
            labelDict = {}
            predict_label = np.zeros((1,y_len),dtype=np.float32)
            for di in dis_index:
                k_label = self.transfor_to_label(self._Y[di])
                if k_label in labelDict:
                    labelDict[k_label] += 1
                else:
                    labelDict[k_label] = 1
            # the most frequent label among the K neighbours wins the vote
            predict_label[0,sorted(labelDict.items(), key=lambda x: x[1], reverse=True)[0][0]] = 1.0
            labels[x] = predict_label
        return labels

    def inputs(self,X,Y):
        self._X = X
        self._Y = Y

    def train(self,X,Y):
        self.inputs(X,Y)

def knn_train_mnist():
    knn = KNN()
    mnist = load_mnist()  # load_mnist() is defined in the data-loading script above
    train_mnist_x = mnist.train.images
    train_mnist_y = mnist.train.labels
    test_mnist_x = mnist.test.images
    test_mnist_y = mnist.test.labels
    END = mnist.test.labels.shape[0]
    knn.train(train_mnist_x, train_mnist_y)
    all_wrong_count = 0
    step = 100  # evaluation batch size (not defined in the original snippet; any reasonable chunk size works)
    for iter in np.arange(0,END,step):
        end = iter+step
        predict_labels = knn.predict(test_mnist_x[iter:end])
        predict_indexes = np.argmax(predict_labels, axis=1)
        real_indexes = np.argmax(test_mnist_y[iter:end], axis=1)
        differ_indexes = predict_indexes - real_indexes
        wrong_count = len(np.where(differ_indexes != 0)[0])  # a mismatch in either direction is an error
        all_wrong_count += wrong_count
        wrong_rate = wrong_count*1.0/(step)
        print(iter//step," batch\twrong rate is: ",wrong_rate)
    print("-*- All wrong rate is:",all_wrong_count*1.0/END)

if __name__ == '__main__':
    knn_train_mnist()

The test results are as follows:
[Screenshot: the per-batch error rates and the overall error rate printed by knn_train_mnist()]

Advantages: simple and easy to use.
Disadvantages: prediction is slow (every query scans the entire training set) and the method is not entirely stable; the result can change when K changes, so K should not be chosen too large.
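Because the result depends on K, one practical way to choose it is to compare a few candidate values on part of the test set before committing to one. A rough sketch reusing the KNN class above (the helper name sweep_k, the candidate values, and the subset size n are arbitrary choices):

def sweep_k(knn, test_x, test_y, candidates=(1, 3, 5, 7, 9), n=1000):
    # evaluate each candidate K on the first n test samples and report the error rate
    real = np.argmax(test_y[:n], axis=1)
    for k in candidates:
        pred = np.argmax(knn.predict(test_x[:n], K=k), axis=1)
        error = np.mean(pred != real)
        print("K =", k, "error rate =", error)

It could be called right after knn.train(...), for example sweep_k(knn, test_mnist_x, test_mnist_y).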
