The KNN Algorithm Explained, with a Java Implementation

1. The KNN Algorithm

KNN is one of the simplest and most commonly used classification algorithms; it is a supervised-learning classification method.

1.1. Overview

KNN stands for K Nearest Neighbors. As the name suggests, the choice of K is crucial.

The idea behind KNN is that, to predict the class of a new point x, we look at the classes of the K points closest to it and assign x to the class that appears most often among them.

That may sound a little abstract, so here is an example:

[Figure: KNN example with K = 3]

The green point in the figure is the one we want to classify. Suppose K = 3. KNN finds the three points closest to it (circled in the figure) and checks which class is in the majority. In this example the blue triangles outnumber the red circles, so the new green point is classified as a blue triangle.

[Figure: KNN example with K = 5]

However, when K = 5 the decision changes: this time the red circles are in the majority, so the green point is classified as a red circle. This example shows how much the choice of K matters.

For KNN, the two most important ingredients are the choice of K and the way distances between points are computed.

1.2. Distance Calculation

KNN typically uses the Euclidean distance. Taking the two-dimensional plane as an example, the Euclidean distance between two points is:

d = √((x1 − x2)² + (y1 − y2)²)

This is simply the distance between the points (x1, y1) and (x2, y2); for example, the distance between (1, 2) and (4, 6) is √(3² + 4²) = 5.

Extending to n-dimensional space, for two points x = (x1, …, xn) and y = (y1, …, yn), the formula becomes:

d = √((x1 − y1)² + (x2 − y2)² + … + (xn − yn)²)

The most straightforward, brute-force form of KNN computes the distance from the query point to every training point, stores and sorts the distances, takes the K smallest, and checks which class appears most often among them.
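
As a quick illustration, here is a minimal, self-contained sketch of that brute-force approach in Java. The class name KnnSketch and the toy data are purely illustrative and are not part of the implementation discussed in section 2; each sample stores its feature values followed by its class label in the last position.

import java.util.*;

public class KnnSketch {
    // Classify a query point by majority vote among its k nearest samples (brute force).
    static double classify(double[] query, List<double[]> samples, int k, int featureCount) {
        // Sort all samples by Euclidean distance to the query point
        samples.sort(Comparator.comparingDouble((double[] s) -> distance(query, s, featureCount)));
        // Count the class labels (stored at index featureCount) among the k nearest samples
        Map<Double, Integer> votes = new HashMap<>();
        for (double[] s : samples.subList(0, Math.min(k, samples.size()))) {
            votes.merge(s[featureCount], 1, Integer::sum);
        }
        // Return the label with the most votes
        return Collections.max(votes.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    static double distance(double[] a, double[] b, int featureCount) {
        double sum = 0;
        for (int i = 0; i < featureCount; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return Math.sqrt(sum);
    }

    public static void main(String[] args) {
        List<double[]> samples = new ArrayList<>(Arrays.asList(
                new double[]{0.1, 0.9, 1}, new double[]{0.2, 0.8, 1},
                new double[]{0.9, 0.1, 0}, new double[]{0.8, 0.2, 0},
                new double[]{0.7, 0.3, 0}));
        // Two label-1 points and one label-0 point are the 3 nearest, so this prints 1.0
        System.out.println(classify(new double[]{0.25, 0.75}, samples, 3, 2));
    }
}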

1.3. Choosing K

The example in 1.1 showed that the value of K matters, so how do we decide what K should be?

The answer is cross-validation: split the sample data into a training portion and a validation portion at some ratio (for example 6:4), start from a small K and keep increasing it, compute the error rate on the validation set each time, and finally pick a K that works well.
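
As a rough sketch of that procedure, the following method (illustrative only, added to the KnnSketch class above and reusing its classify method) tries every K up to some maximum on a held-out validation split and keeps the K with the lowest error rate:

    // Pick the K with the lowest error rate on a held-out validation split.
    static int chooseK(List<double[]> train, List<double[]> validation, int maxK, int featureCount) {
        int bestK = 1;
        double bestError = Double.MAX_VALUE;
        for (int k = 1; k <= maxK; k++) {
            int wrong = 0;
            for (double[] v : validation) {
                // Predict the class of the validation point and compare it with its true label
                if (classify(v, train, k, featureCount) != v[featureCount]) {
                    wrong++;
                }
            }
            double error = (double) wrong / validation.size();
            if (error < bestError) {
                bestError = error;
                bestK = k;
            }
        }
        return bestK;
    }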

In general, after computing the validation error for a range of K values you will get a plot roughly like the one below:

[Figure: validation error rate as a function of K]

This plot is easy to interpret. As K increases, the error rate usually drops at first, because more neighboring samples are taken into account and classification improves. But note that when K grows too large, the error rate rises again. This also makes sense: if you have only 35 samples in total and K grows to 30, KNN essentially loses its meaning.

So when choosing K, pick the critical point where the error rate is at its minimum, i.e. the value from which either increasing or decreasing K pushes the error rate back up, for example K = 10 in the plot.

1.4. Strengths and Weaknesses of KNN

1.4.1. Advantages

  1. Simple and easy to use: compared with other algorithms, KNN is about as straightforward as it gets, and you can understand how it works without a strong mathematical background.
  2. Training is fast, since KNN has essentially no training phase beyond storing the data.
  3. Predictions are reasonably accurate.
  4. Not sensitive to outliers.

1.4.2. Disadvantages

  1. High memory usage, because the algorithm keeps all of the training data.
  2. The prediction phase can be slow.
  3. Sensitive to irrelevant features and to the scale of the data.

2. Java Implementation

2.1. Code

import java.util.*;

public class Knn {
    /**
     *
     * @Title: cal
     * @Description: classify a query point by majority vote among its k nearest neighbors
     * @param centerPoint the query point to classify
     * @param data the known data set; each row holds the feature values followed by the class label
     * @param k the number of nearest neighbors to consider
     * @param feature the number of feature columns (the class label is stored at index feature)
     * @return a map with the k nearest points ("knn_point") and the predicted class ("type")
     */
    public static Map<String, Object> cal(Double[] centerPoint, List<Double[]> data, int k, int feature) {
        // Compute the Euclidean distance from every known point to the query point
        List<Map<String, Object>> tmpList = new ArrayList<Map<String, Object>>();
        for (int i = 0; i < data.size(); i++) {
            Double[] bArr = data.get(i);
            Double euclidDis = euclidDistance(centerPoint, bArr,feature);
            if (euclidDis == 0) continue; // skip points identical to the query point (e.g. the query itself)
            Map<String, Object> map = new HashMap<String, Object>();
            map.put("dis", euclidDis);
            map.put("index", i);
            map.put("type", bArr[feature]);
            tmpList.add(map);
        }

        // Sort all known points by distance to the query point
        Collections.sort(tmpList, new Comparator<Map<String, Object>>() {
            public int compare(Map<String, Object> f1, Map<String, Object> f2) {
                return Double.compare((double) f1.get("dis"), (double) f2.get("dis"));
            }
        });

        // Take the k nearest points
        List<Map<String, Object>> tmpListSub = tmpList.subList(0, Math.min(k, tmpList.size()));

        // Count how many of the k nearest points belong to each class
        Map<String,Integer> classify = new HashMap<>();
        for (Map<String, Object> map : tmpListSub) {
            String type = map.get("type")+"";
            classify.merge(type,1,Integer::sum);
        }
        // Find the class with the most votes
        int best = 0;
        String type = "";
        for (Map.Entry<String, Integer> entry : classify.entrySet()) {
            if (entry.getValue() > best) {
                type = entry.getKey();
                best = entry.getValue();
            }
        }
        Map<String,Object> result = new HashMap<>();
        result.put("knn_point",tmpListSub);
        result.put("type",type);
        return result;
    }

    /**
     *
     * @Title: euclidDistance
     * @Description: Euclidean distance over the first {@code feature} components of a and b
     * @param a the first point
     * @param b the second point
     * @param feature the number of feature dimensions to compare
     * @return Double the Euclidean distance
     */
    public static Double euclidDistance(Double[] a, Double[] b, int feature) {
        double sum = 0;
        for (int i = 0; i < feature; i++) {
            sum += Math.pow(a[i] - b[i], 2);
        }
        return Math.sqrt(sum);
    }

    public static void execute(String trainPath, String testPath, int k, int feature) {
        try {
            // FileUtils and DrawPic are helper classes (file parsing and plotting) not shown in this post
            List<Double[]> trains = FileUtils.readDoubles(trainPath, "\t");
            List<Double[]> tests  = FileUtils.readDoubles(testPath, "\t");
            DrawPic drawPic = new DrawPic();
            for (Double[] test : tests) {
                Map<String, Object> classifyResult = cal(test, trains, k, feature);
                Double type = Double.parseDouble(classifyResult.get("type") + "");
                // Shift the predicted label by 2 so test points are drawn as their own series (2 and 3)
                test[feature] = type + 2;
                System.out.println(classifyResult.toString());
            }
            drawPic.add(trains, 0d);
            drawPic.add(trains, 1d);
            drawPic.add(tests, 2d);
            drawPic.add(tests, 3d);
            drawPic.draw("KNN TEST RESULT:");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        int k = 5; // parameter 1: the number of nearest neighbors
        int feature = 2; // parameter 2: the number of features
        execute("test/trainKnn.txt", "test/testKnn.txt", k, feature);
    }
}
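
Note that the code above relies on two helper classes that are not included in the post: FileUtils.readDoubles, which parses a delimited text file into rows of Double values, and DrawPic, which plots the classified points. Their exact implementations are unknown; a minimal reader with a compatible signature might look like the following sketch (assumed behaviour only):

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class FileUtils {
    // Read a delimited text file into rows of Double values (one row per line).
    // This is only a guess at the helper used above; the original class is not shown in the post.
    public static List<Double[]> readDoubles(String path, String delimiter) throws IOException {
        List<Double[]> rows = new ArrayList<>();
        for (String line : Files.readAllLines(Paths.get(path))) {
            if (line.trim().isEmpty()) {
                continue; // skip blank lines
            }
            // Treat runs of the delimiter as a single separator
            String[] parts = line.trim().split(delimiter + "+");
            Double[] row = new Double[parts.length];
            for (int i = 0; i < parts.length; i++) {
                row[i] = Double.parseDouble(parts[i]);
            }
            rows.add(row);
        }
        return rows;
    }
}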

2.2. Training set:

0.559159703 0.098352957    0
0.781938472    0.351149046    0
0.322619685    0.68735181 1
0.424119019    0.800061286    1
0.963396034    0.507747552    0
0.025621522    0.421294488    0
0.256461082    0.42557099 0
0.18307667 0.292654328    0
0.639442992    0.327507571    0
0.11249193 0.415212654    0
0.089114599    0.857211652    1
0.249131294    0.135416934    0
0.860781951    0.186191506    0
0.595163818    0.116916255    0
0.271419895    0.449691344    0
0.43646567 0.93251139 1
0.015846183    0.261730053    0
0.774858773    0.808469061    0
0.742867778    0.733759061    0
0.74295491 0.992062927    0
0.853975626    0.977766081    0
0.569375171    0.200101637    0
0.223755566    0.408569886    0
0.62065295 0.699191534    0
0.56423325 0.903351997    1
0.93734134 0.22314747 0
0.956509155    0.352316866    0
0.488649148    0.687423546    0
0.553580938    0.081469598    0
0.218311577    0.278594184    0
0.435339431    0.645957677    0
0.369294227    0.80281829 1
0.052275118    0.55456295 1
0.650081762    0.291811967    0
0.801814551    0.130852795    0
0.862452749    0.950437897    0
0.689082808    0.251889264    0
0.09853475 0.412371885    0
0.289120272    0.068057336    0
0.135108188    0.651060347    1
0.55931731 0.714674629    0
0.386982056    0.761612126    1
0.508901062    0.552195358    0
0.085845048    0.519086874    1
0.748453007    0.259367763    0
0.023389989    0.360986147    0
0.403579896    0.053640904    0
0.573968451    0.625821009    0
0.904036937    0.494519806    0
0.357751222    0.038605127    0
0.289068703    0.996306106    1
0.324216807    0.247300728    0
0.84744147 0.282028386    0
0.174259437    0.10879232 0
0.65723266 0.162442422    0
0.647686639    0.134837519    0
0.438747109    0.391559812    0
0.212046463    0.446433524    0
0.923317003    0.612253639    0
0.922838542    0.479557344    0
0.930696766    0.741726617    0
0.952616342    0.406805062    0
0.38932093 0.359289985    0
0.263773034    0.197158645    0
0.052055233    0.58522216 1
0.661426754    0.176925664    0
0.577817595    0.885369811    1
0.410437881    0.350911589    0
0.25447879 0.12295319 0
0.69820201 0.938979225    0
0.708159093    0.097565246    0
0.170464713    0.781592102    1
0.596360062    0.69804899 0
0.044356167    0.436906786    0
0.833448417    0.733349389    0
0.312559922    0.677245526    1
0.515330652    0.788754853    1
0.306931908    0.368021736    0
0.120495932    0.134803063    0
0.195167977    0.715243725    1
0.588992255    0.129598264    0
0.660493659    0.97095288 1
0.683932542    0.904360467    0
0.62728259 0.219561202    0
0.93361751 0.795767678    0
0.611341157    0.183558451    0
0.767462911    0.223386048    0
0.196077886    0.174656028    0
0.168137091    0.949702934    1
0.486978259    0.612132859    0
0.301523984    0.859036882    1
0.859118552    0.028799836    0
0.856611283    0.286111664    0
0.954981271    0.54964407 0
0.766554927    0.925246588    0
0.988669688    0.826772066    0
0.673578898    0.288661704    0
0.817091353    0.591044774    0
0.789128146    0.559132763    0
0.886453397    0.014503077    0
0.240633555    0.010275127    0
0.347477609    0.59875547 0
0.723260725    0.793767401    0
0.933630534    0.194448974    0
0.210915933    0.11436068 0
0.050058141    0.769411281    1
0.784958505    0.413175444    0
0.241895   0.568624123    1
0.673796133    0.633554186    0
0.3767971  0.01176521 0
0.65704224 0.053507921    0
0.304597084    0.370059785    0
0.683021664    0.636328542    0
0.779393545    0.293803541    0
0.620436623    0.365085781    0
0.849167085    0.572128924    0
0.467997506    0.756505709    1
0.581376199    0.602764684    0
0.001670169    0.98407763 1
0.882013039    0.801522851    0
0.653121628    0.251655431    0
0.55796115 0.518879925    0
0.596073047    0.69744254 0
0.031032908    0.299830479    0
0.072822929    0.862113671    1
0.9224477  3.28E-04   0
0.218801332    0.776612516    1
0.882055463    0.738065117    0
0.812817791    0.780229273    0
0.145641594    0.7987032  1
0.740037534    0.835949736    0
0.925142296    0.49013454 0
0.727447795    0.040233445    0
0.792203525    0.175172446    0
0.421066491    0.065706972    0
0.792427042    0.836907325    0
0.12603555 0.308458655    0
0.486083138    0.265620163    0
0.113550239    0.237603472    0
0.578618525    0.058056105    0
0.491027782    0.24279796 0
0.044068206    0.196272695    0
0.188194963    0.876208421    1
0.59427145 0.227052602    0
0.709731721    0.549117564    0
0.114626924    0.152160306    0
0.488955732    0.866743922    1
0.658656048    0.785276781    0
0.208374543    0.06776244 0
0.735047094    0.105306509    0
0.018805699    0.699531941    1
0.790984119    0.580100981    0
0.841962559    0.215964196    0
0.993733368    0.678689475    0
0.084575646    0.715900447    1
0.562700931    0.552142048    0
0.744381188    0.809232821    0
0.735690924    0.179272592    0
0.842038774    0.483817611    0
0.858868018    0.968614631    0
0.758880331    0.978954864    0
0.061272032    0.884061512    1
0.120795761    0.308392785    0
0.861511168    0.55556587 0
0.37500318 0.523331719    0
0.985894258    0.158554434    0
0.561795575    0.618856707    0
0.0285602  0.876746258    1
0.366750794    0.085376453    0
0.488624589    0.514958198    0
0.313223701    0.582710728    0
0.143654389    0.344754984    0
0.412079142    0.661913686    0
0.50899188 0.127682629    0
0.056833548    0.982670712    1
0.438638683    0.915229256    1
0.101613959    0.815242435    1
0.990950574    0.423789218    0
0.406947674    0.321068723    0
0.728123969    0.279372892    0
0.762987341    0.458203973    0
0.954334935    0.217322227    0
0.361464321    0.846310764    1
0.778972641    0.095845214    0
0.43760825 0.436413274    0
0.017785997    0.48031178 0
0.164948023    0.472827548    0
0.666601327    0.084190442    0
0.235169627    0.026138001    0
0.634839006    0.707822041    0
0.979204192    0.961152799    0
0.6913863  0.028671717    0
0.196003941    0.216732298    0
0.397462012    0.175089452    0
0.578839539    0.032593579    0
0.890593425    0.909356154    0
0.48917651 0.311922377    0
0.777101256    0.631189695    0
0.40518121 0.918050457    1
0.158785075    0.727937457    1
0.389705704    0.05235087 0
0.490394828    0.946309155    1
0.774812563    0.293589598    0
0.356909841    0.669024069    1
0.891840695    0.869408375    0
0.571528976    0.79408773 0
0.459172231    0.827375092    1
0.118353158    0.539058174    1
0.539115373    0.699284255    0
0.012149967    0.185307544    0
0.271625184    0.808299236    1
0.309702694    0.05801239 0
0.735241331    0.139080759    0
0.265153681    0.341914159    0
0.016672641    0.775696187    1
0.120524733    0.003039323    0
0.88965971 0.845131836    0
0.334308136    0.72586824 1
0.254865219    0.077870714    0
0.571129118    0.206600035    0
0.27516719 0.023768075    0
0.918960966    0.821465102    0
0.855880282    0.835429484    0
0.106598832    0.581344709    1
0.453630222    0.974979567    1
0.166576055    0.728092768    1
0.504998198    0.866662439    1
0.214741409    0.875432398    1
0.37814347 0.695683246    1
0.066920806    0.79775841 1
0.283672719    0.935663685    1
0.027358126    0.731799978    1
0.550928494    0.898764484    1
0.060676932    0.717959698    1
0.163306527    0.926933468    1
0.024396637    0.988913373    1
0.197800921    0.970003859    1
0.178661291    0.79708392 1
0.386520536    0.949415833    1
0.435466533    0.722697974    1
0.053350725    0.569663353    1
0.292447402    0.664881407    1
0.509224634    0.976647829    1
0.222992322    0.881366985    1
0.247824961    0.706386584    1
0.082784155    0.954941134    1
0.342328962    0.626215951    1
0.273804754    0.720389805    1
0.331594477    0.945419326    1
0.426170805    0.739214852    1
0.096871736    0.715779384    1
0.128907263    0.730137703    1
0.665925074    0.977890165    1
0.066070264    0.977908886    1
0.109256724    0.849794811    1
0.036701381    0.732651833    1
0.558859425    0.873114233    1
0.21952731 0.686461173    1
0.26794263 0.647051569    1
0.010089968    0.635421137    1
0.057019109    0.700209208    1
0.402620087    0.719909979    1
0.240416309    0.668431956    1
0.539578702    0.900015605    1
0.101637325    0.92766932 1
0.295172013    0.900548476    1
0.358548944    0.92749591 1
0.198487527    0.562198301    1
0.368583536    0.942652127    1
0.031636113    0.871328941    1
0.183909513    0.958116981    1
0.247770988    0.96926925 1
0.277184324    0.747748294    1
0.429529408    0.768752855    1
0.190304499    0.753790139    1
0.018189242    0.831992234    1
0.32890642 0.633875357    1
0.658258842    0.938963185    1
0.251038211    0.652845006    1
0.093920028    0.806054176    1
0.181317893    0.664991991    1
0.248188847    0.912275228    1
0.238839691    0.858678587    1
0.23117987 0.61764752 1
0.105243287    0.60234405 1
0.387166308    0.984512662    1
0.338112572    0.640637109    1
0.208831634    0.966547831    1
0.433571431    0.707981672    1
0.197061929    0.684306376    1
0.468532867    0.879279234    1
0.479906353    0.858351595    1
0.175381995    0.974871752    1
0.29651653 0.890605848    1
0.044588102    0.921254447    1
0.158678506    0.641707   1
0.608506183    0.934815632    1
0.108912757    0.741997915    1
0.112778699    0.758458386    1
0.650237524    0.977904349    1
0.225452235    0.994745743    1
0.32585813 0.820238557    1
0.071192275    0.706700706    1
0.211623369    0.57700678 1
0.080363048    0.608812715    1
0.069701571    0.56315952 1
0.158791014    0.750694597    1
0.199008247    0.969003189    1
0.30536649 0.695472553    1
0.029767776    0.963198515    1
0.158903774    0.708398213    1
0.064347166    0.903413621    1
0.380673595    0.986659914    1
0.035822485    0.873671297    1
0.236370823    0.685758662    1
0.108768375    0.830937071    1
0.185502771    0.70172469 1
0.284665615    0.90656859 1
0.097703054    0.966858102    1
0.111640399    0.636321974    1
0.321916548    0.695808099    1
0.433719757    0.710247603    1
0.659811098    0.958967393    1
0.54616972 0.979870303    1
0.101174754    0.928101789    1
0.54675508 0.812950702    1
0.238457822    0.833707629    1
0.259079919    0.74382291 1
0.468939545    0.878297214    1
0.063752611    0.791115422    1
0.451180706    0.777552159    1
0.615368285    0.951344541    1
0.466177954    0.886348385    1
0.43100112 0.902341329    1
0.299937316    0.66623742 1
0.240543312    0.589234033    1
0.605841085    0.874745491    1
0.511797692    0.920812104    1
0.480809987    0.882446856    1
0.101015819    0.593120953    1
0.157909966    0.751703819    1
0.035707792    0.84279398 1
0.179329884    0.565170806    1
0.056522575    0.918872345    1
0.54357903 0.938701118    1
0.225314744    0.992520345    1
0.078614722    0.753872004    1
0.642566854    0.966021515    1
0.093130905    0.879236185    1
0.102617285    0.768702192    1
0.361220713    0.865849089    1
0.406597742    0.863961868    1
0.10440507 0.688093108    1
0.411637157    0.798814931    1
0.575299042    0.878158869    1
0.475463924    0.990202916    1
0.340432746    0.939165645    1
0.246234203    0.707766799    1
0.328560776    0.683569936    1
0.550646781    0.956722592    1
0.024359445    0.557174209    1
0.598933162    0.901057044    1
0.0706504  0.578609654    1
0.160536278    0.944209111    1
0.224303387    0.97965662 1
0.28300118 0.660712504    1
0.372145157    0.756368999    1
0.248521112    0.820065825    1
0.37907606 0.788056062    1
0.446980501    0.909372129    1
0.213295909    0.995305476    1
0.388695001    0.75934272 1
0.260641363    0.810867931    1
0.489522049    0.830323783    1
0.107112854    0.851134269    1
0.048777929    0.994415964    1
0.516192628    0.899072342    1
0.27749659 0.771527235    1
0.14368047 0.733693082    1
0.088902534    0.756926048    1
0.561850419    0.852981461    1
0.112498495    0.852516795    1
0.32249627 0.630457536    1
0.55497113 0.828490068    1
0.032283253    0.938501487    1
0.147883264    0.675961561    1
0.171231541    0.642979677    1
0.330671791    0.983155598    1
0.069693762    0.985000457    1
0.596641628    0.927682208    1
0.094448045    0.981347068    1
0.233576118    0.587407127    1
0.380471393    0.863043035    1
0.244991556    0.831641311    1
0.476147527    0.874005371    1
0.011621024    0.842144471    1
0.353009932    0.659625741    1

2.3. Test set:

0.117347857 0.664476028    1
0.294670399    0.621357066    1
0.282788426    0.643479382    1
0.174460982    0.883150328    1
0.074223377    0.762623324    1
0.265698816    0.639501657    1
0.025308361    0.814070821    1
0.145211188    0.548882643    1
0.315994193    0.965642323    1
0.160564179    0.427239984    0
0.319108869    0.116817648    0
0.829953328    0.77263275 0
0.682534864    0.653759154    0
0.129320255    0.351237019    0
0.78742235 0.197266104    0
0.885035805    0.771164421    0

2.4. Test results (the predicted class for each test point, in the same order as the test set):

1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0

[Figure: scatter plot of the training and test points (red = training class 0, blue = training class 1, green = test class 0, yellow = test class 1)]

Reposted from blog.csdn.net/u012998680/article/details/120568649