邻近分类算法---KNN(Java实现+公式计算例子)

名词解释:

KNN:K-Nearest-Neighbor K值最邻近
所以网上有好几种叫法:最近邻,邻近…其实都是指的KNN,其实按照翻译都对。KNN算法产生于1968年,是数据挖掘和数据分类技术中最简单的入门级算法。
按照字面意思,就是计算出距离K值最近的邻居数据,设对应数据为X,然后把待分类数据归类为X。

原理

基于大量不同维度的训练数据,做循环比较分析,具体经历下面四步流程:
①准备数据,对数据进行预处理 。
②计算测试样本点(也就是待分类点)到其他每个样本点的距离 。
③对每个距离进行排序,然后选择出距离最小的K个点。
④对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类
更通俗点来讲:
类似孟母三迁,物以类聚人以群分,如果待分类数据的各项数值偏向于一定范围K内的最多的一种类型X,那么大胆预估该数据就属于类型X

公式

既然该算法的核心就是计算距离,所以它有下面几个公式

(1)欧氏距离

也称欧几里得距离,是最常见的距离度量,衡量的是多维空间中两个点之间的绝对距离。可以把一个点想象成原点,计算目标点到原点的直线距离,这是高中时候最容易理解的知识.
在这里插入图片描述

(上图摘自:https://blog.csdn.net/bluesliuf/article/details/88862918)

曼哈顿距离

又称出租车距离,跟欧式直线距离不同,想象你打出租车去曼哈顿区,肯定不能走直线,路上有围墙,房屋,公园,你要绕,左左右右上上下下,东西南北的走,而这段曲折的非直线的距离,就是曼哈顿距离
在这里插入图片描述

(上图摘自:https://blog.csdn.net/WangTaoTao_/article/details/102973124)
更清晰了解曼哈顿距离和欧式距离的区别,可以看百度的这张图
在这里插入图片描述

切比雪夫距离

在上面两种距离中,欧式距离两个斜方向,有点无视‘障碍物’,力争最短距离的感觉,而曼哈顿距离只有上下左右(东南西北)四个距离,这里介绍的切比雪夫距离,可以想象成是欧式和曼哈顿的结合体,一种‘超人’,既可以循规蹈矩上下左右,也可以斜着无视障碍物走,这就是切比雪夫距离(类似国际象棋中的国王)
在这里插入图片描述

(上图摘自:https://blog.csdn.net/WangTaoTao_/article/details/102973124)

马氏距离

在上面介绍的三种距离公式,有一个共同的缺陷,就是当多维度数据分析中,量词单位不一致导致的差距过大或过小,例如:
小明身高173cm,体重50000g
小黄身高162cm,体重50000g
设:小强身高175cm,体重60000g
求:判断小强体型与谁的类型更相近

如果按照上面的距离公式,由于体重的量词是g,所以数值偏大,而身高是cm,甚至m,这个重要指标的数值反而偏小,这就是网上说的“受量纲关系影响”,导致归类识别结果是:
小强体型与小黄相近

大错特错,一个175,一个162,差了一大截,预测结果还说体型相近,这个时候就要用到马氏距离公式

标准差:

是一组数值自平均值分散开来的程度的一种测量观念。一个较大的标准差,代表大部分的数值和其平均值之间差异较大;一个较小的标准差,代表这些数值较接近平均值。
公式意义
所有数(个数为n)记为一个数组[n]。将数组的所有数求和后除以n得到算术平均值。数组的所有数分别减去平均值,得到的n个差值分别取平方,再将得到的所有平方数求和,然后除以数的个数或个数减一(若所求为总体标准差则除以n,若所求为样本标准差则除以(n-1)),最后把得到的商取算术平方根,就是取1/2次方,得到的结果就是这组数(n个数据)的标准差。

方差:

方差是标准差的平方,而标准差的意义是数据集中各个点到均值点距离的平均值。反应的是数据的离散程度。

协方差:

在这里插入图片描述

协方差矩阵:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

马氏距离计算例子:

在这里插入图片描述
以上就是一个马氏距离的例子,具体详细的计算步骤我也写出来了(网上很多大神觉得太简单了没列),用代码实现的话就不难了

代码实现

假设有业务需求:使用现有的一批训练数据,使用邻近分类法帮新用户进行分类

(注意这里的距离计算公式我只是简单的求了距离绝对值,严格来说要按照业务需求判断使用上述的哪种距离公式,贴合具体场景使用,不可一味照搬)

package com.example.map.math;

import org.apache.catalina.LifecycleState;

import java.util.*;

/**
 * @author jojo
 * 近邻分类器
 */
public class KNN {
    
    

    // 数据模型
    public static class KnnModel implements Comparable<KnnModel>{
    
    
        public double paramA;
        public double paramB;
        public double paramC;
        public double distance;
        String type;
        public double getDistance() {
    
    
            return distance;
        }

        public KnnModel(double a, double b, double c, String type) {
    
    
            this.paramA = a;
            this.paramB = b;
            this.paramC = c;
            this.type = type;
        }

        @Override
        public int compareTo(KnnModel model) {
    
    
            return Double.compare(this.distance, model.distance);
        }
    }

    /** 计算新数据与训练数据的距离 **/
    private static List<KnnModel> calculate(List<KnnModel> modelList, KnnModel model, int k) {
    
    
        for (KnnModel m : modelList) {
    
    
            double distanceA = Math.abs(model.paramA - m.paramA);
            double distanceB = Math.abs(model.paramB - m.paramB);
            double distanceC = Math.abs(model.paramC - m.paramC);
            double gap = distanceA + distanceB + distanceC;
            // 训练数据保存与目标数据的距离,方便下一步排序
            m.distance = gap;
        }
        // 根据distance大小进行排序(从小到大)
        Collections.sort(modelList, Comparator.comparingDouble(KnnModel::getDistance));
        // 返回差距最小的k个值
        List<KnnModel> resultList = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
    
    
            resultList.add(modelList.get(i));
        }
        return resultList;
    }

    /** 统计出最多的类型 **/
    private static String findTypeByScope(List<KnnModel> modelList) {
    
    
        Map<String, Integer> typeMap = new HashMap<>(modelList.size());
        // 统计类型
        for (KnnModel model : modelList) {
    
    
            int sum = typeMap.get(model.type) == null ? 1 : typeMap.get(model.type) + 1;
            typeMap.put(model.type, sum);
        }
        // 返回出现次数最多的类型
        List<Map.Entry<String,Integer>> list = new ArrayList(typeMap.entrySet());
        Collections.sort(list, Comparator.comparingInt(Map.Entry::getValue));
        return list.get(list.size()-1).getKey();
    }

    /** Knn
     * @param modelList 训练数据集
     * @param model 待分类数据
     * @param k 范围变量
     * */
    public static String calculateKnn(List<KnnModel> modelList, KnnModel model, int k) {
    
    
        // (1) 计算训练数据与待分类数据的各自相对距离,并返回差距最小的K个训练结果
        List<KnnModel> minKnnList = calculate(modelList, model, k);
        // (2) 找出差距最小的K个结果中,最多的类型
        return findTypeByScope(minKnnList);
    }

    public static void main(String[] args) {
    
    
        // 准备数据(假设参数A为身高,B为体重,C为颜值,用类型分为'帅哥''普通''屌丝')
        List<KnnModel> knnModelList = new ArrayList<>();
        knnModelList.add(new KnnModel(178, 75, 88, "帅哥"));
        knnModelList.add(new KnnModel(180, 73, 96, "帅哥"));
        knnModelList.add(new KnnModel(183, 80, 95, "帅哥"));
        knnModelList.add(new KnnModel(173, 75, 95, "普通"));
        knnModelList.add(new KnnModel(170, 72, 80, "普通"));
        knnModelList.add(new KnnModel(171, 71, 89, "普通"));
        knnModelList.add(new KnnModel(155, 70, 60, "屌丝"));
        knnModelList.add(new KnnModel(159, 80, 68, "屌丝"));
        knnModelList.add(new KnnModel(160, 75, 70, "屌丝"));
        // 预测数据
        KnnModel model = new KnnModel(176.5, 70, 92, null);
        // 输出预测类型
        System.out.println(calculateKnn(knnModelList, model, 3));
    }

}

执行结果:
在这里插入图片描述

注意⚠️,这个结果不严谨,只是简单的演示大概的代码,严格来说要按照业务需求判断使用上述的哪种距离公式,贴合具体场景使用,不可一味照搬

K值的选择

综上所述,我们可以看到K值的取值,直观影响我们的分类结果,所以K值的范围也是大有讲究,甚至这么一个小小的参数,也是需要大量数据分析训练得出来的
K值的误差要根据实际数据和训练数据做判断,限制在合理范围,可以用类似指数平滑之类的预测算法计算K值的最小误差,其余的暂不展开研究,待日后补充
Java手写三次指数平滑算法

(上面是本人做另外一个需求做的三次指数平滑,如果对精确度要求没那么高,一二次的指数平滑也能确定K值)

优点

(1)简单易懂

缺点

(1)计算量大
(2)庞大的训练样本数据才能提高准确性
(3)准确率较低

总结:

以上就是个人的理解和总结,如果有错误的地方欢迎指正一起学习~

猜你喜欢

转载自blog.csdn.net/whiteBearClimb/article/details/123007830