KNN实现鸢尾花分类

数据集

使用最常见的鸢尾花作为数据集,有好多人将鸢尾花的数据集上传并且付费,我一言难尽。
大家可以从我上传的资源直接下载,资源免费

将其按照7:3的比例划分为训练集和测试集。前4列为特征,第5列为类别,"setosa"视为1,"versicolor"视为2,"virginica"视为3;测试集顺序略有调整。

KNN

最简单最初级的分类器是将全部的训练数据所对应的类别都记录下来,当测试对象的属性和某个训练对象的属性完全匹配时,便可以对其进行分类。但是怎么可能所有测试对象都会找到与之完全匹配的训练对象呢,其次就是存在一个测试对象同时与多个训练对象匹配,导致一个训练对象被分到了多个类的问题,基于这些问题,就产生了KNN。

KNN是通过测量不同特征值之间的距离进行分类。它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中k通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

其步骤如下:

  1. 计算测试数据与各个训练数据之间的距离;
  2. 按照距离的递增关系进行排序;
  3. 选取距离最小的K个点;
  4. 确定前K个点所在类别的出现频率;
  5. 返回前K个点中出现频率最高的类别作为测试数据的预测分类。

Matlab实现

基于以上步骤,可以很容易写出matlab代码,代码有详细注释:

clear all;
clc;
% 读入训练集,分类:"setosa"视为1"versicolor"视为2"virginica"视为3
train = xlsread('train.xlsx');
trainData = train(:,1:4);% 前四列是特征
trainClass = train(:,5);% 最后一列是分类

% 读入测试集
test = xlsread('test.xlsx');
testData = test(:,1:4);
N = size(testData,1);% 测试数据个数
testClass = zeros(1,N);% 定义数组存放分类结果
testClassCorrect = test(:,5)';% 正确结果,为之后判断准确率使用

% 主函数
k = 9;% 确定k的值,取值不同分类结果不同
row = size(trainData,1);% 返回行数
col = size(trainData,2);% 返回列数

% 开始for循环,对每一个测试数据进行分类
for n = 1:N
    itest = testData(n,:);% 找出一个测试集
    itestrep = repmat(itest,row,1);% 此步骤将测试数据复制为一个和训练集大小相同的矩阵,为方便后面的循环计算
    dis = zeros(1,row);% 有很多种距离计算方法,此处选择欧氏距离
    for i = 1:row
        diff = 0;
        for j = 1:col
            diff = diff + (itestrep(i,j) - trainData(i,j)).^2;
        end
        dis(1,i) = diff.^0.5;
    end
    
    % 对距离进行排序,选出距离最小的k个邻近
    Class = trainClass';
    joinClass = [dis;Class];
    sortDis= sortrows(joinClass');% 对加入分类后的矩阵按照距离升序排序
    sortDisClass = sortDis';
    
    % 统计出现次数最高的类别
    ksort = sortDisClass(2,1:k);
    table=tabulate(ksort);
    MaxPercent=max(table(:,3));
    [r,c]=find(table==MaxPercent);% 找出最大概率所在的位置
    MaxValue=table(r,1);
    testClass(1,n) = MaxValue;
end

% 显示结果
disp('最终的分类结果为:');
for x = 1:N
    fprintf('%d\t',testClass(1,x));
end
fprintf('\n');

% 计算正确率
count = 0;% 正确个数
for y = 1:N
    if (testClass(1,y)==testClassCorrect(1,y))
        count = count+1;
    end
end
disp('最终的正确率为:');
fprintf('%.2f%%\n',count/N*100);

结果如下:

最终的分类结果为:
1	1	1	1	1	2	2	2	2	2	1	1	1	1	1
2	2	2	2	2	3	3	3	3	3	2	2	2	2	2
3	3	3	2	3	1	1	1	1	1	3	3	3	3	3	
最终的正确率为:
97.78%

猜你喜欢

转载自blog.csdn.net/qq_45510888/article/details/106158889