基于KNN的离群点检测算法的Matlab版实现

基于KNN的离群点检测(outlier detection)是一种简单而高效的算法,其核心思想是:以对象到其k个最近邻的距离之和作为该对象的离群因子OF(outlier factor);OF值越大,该对象越有可能是离群点。

基于KNN的离群点检测算法的Matlab实现如下:

function [outputArg1,outputArg2] = KNN(inputArg1,inputArg2)
%KNN  KNN-based outlier detection with evaluation against ground-truth labels.
%
%   [OF, auc] = KNN()
%   [OF, auc] = KNN(dataFile, labelFile)
%
%   The outlier factor OF of each object is the sum of its distances to its
%   k nearest neighbours; the larger OF is, the more likely the object is an
%   outlier. The Abnormal_number objects with the largest OF are flagged as
%   outliers. The result is evaluated against the labels (1 = outlier,
%   0 = normal): the AUC, accuracy, detection rate, precision and false-alarm
%   rate are printed, and the confusion matrix is drawn as a heatmap.
%
%   inputArg1  : (optional) path of the data file read with LOAD, one object
%                per row. Defaults to 'Normalization_wbc.txt' when omitted,
%                empty, or not a char/string.
%   inputArg2  : (optional) path of the label file, one label per object.
%                Defaults to 'Label_wbc.txt' under the same rule.
%
%   outputArg1 : m-by-1 vector of outlier factors OF.
%   outputArg2 : AUC of OF against the labels, as returned by Measure_AUC.
%
%   Requires pdist2 (Statistics and Machine Learning Toolbox) and the
%   Measure_AUC function.

dataFile = 'Normalization_wbc.txt';
labelFile = 'Label_wbc.txt';
if nargin >= 1 && (ischar(inputArg1) || isstring(inputArg1))
    inputArg1 = char(inputArg1);
    if ~isempty(inputArg1)
        dataFile = inputArg1;
    end
end
if nargin >= 2 && (ischar(inputArg2) || isstring(inputArg2))
    inputArg2 = char(inputArg2);
    if ~isempty(inputArg2)
        labelFile = inputArg2;
    end
end

x = load(dataFile);
Label = load(labelFile);
Label = Label(:);   % accept row or column label files
m = size(x, 1);
if numel(Label) ~= m
    error('KNN:sizeMismatch', ...
        'The label file has %d entries but the data has %d objects.', numel(Label), m);
end

k = 10;                 % number of nearest neighbours
Abnormal_number = 20;   % number of objects reported as outliers
% Guard against small data sets: at most m-1 other objects can be neighbours,
% and no more than m objects can be flagged.
k = min(k, m - 1);
Abnormal_number = min(Abnormal_number, m);

Dist = pdist2(x, x);
SortDist = sort(Dist, 2, 'ascend');
% Column 1 is each object's zero distance to itself, so take k+1 columns to
% cover its k nearest neighbours (the self term adds 0 to the sum).
Nei_k = SortDist(:, 1:k+1);
OF = sum(Nei_k, 2);

auc = Measure_AUC(OF, Label);
disp(auc)

% Sort OF ascending: the last Abnormal_number objects are the ones the
% algorithm flags as outliers, the rest are considered normal.
[~, index_number] = sort(OF);
ODA_AbnormalObject_Number = index_number(m-Abnormal_number+1:end);  % indices flagged as outliers
ODA_NormalObject_Number = index_number(1:m-Abnormal_number);        % indices considered normal

%%%%%%%%%%%%%%%% Evaluation: accuracy / detection rate / precision / false-alarm rate %%%%%%%%%%%%%%%%
% Ground truth: indices of the truly normal (label 0) and truly abnormal (label 1) objects.
Real_NormalObject_Number = find(Label == 0);
Real_AbnormalObject_Number = find(Label == 1);

% Positive class = outlier, negative class = normal object.
TP = numel(intersect(Real_AbnormalObject_Number, ODA_AbnormalObject_Number)); % outliers flagged
FN = numel(intersect(Real_AbnormalObject_Number, ODA_NormalObject_Number));   % outliers missed
TN = numel(intersect(Real_NormalObject_Number, ODA_NormalObject_Number));     % normals kept
FP = numel(intersect(Real_NormalObject_Number, ODA_AbnormalObject_Number));   % normals flagged

% Accuracy
ACC=(TP+TN)/(TP+TN+FP+FN);
fprintf('准确率ACC= %8.5f\n',ACC*100)
% Detection rate == recall R
DR=TP/(TP+FN);
fprintf('检测率DR= %8.5f\n',DR*100)
% Precision P
P=TP/(TP+FP);
fprintf('查准率P= %8.5f\n',P*100)
% False-alarm rate
FAR=FP/(TN+FP);
fprintf('误报率FAR= %8.5f\n',FAR*100)

% Confusion matrix: rows = actual class, columns = predicted class.
Confusion_matrix=[TP,FN;FP,TN];
classNames = {'Outlier', 'Normal'};
Figure_Confusion_matrix=heatmap(classNames, classNames, Confusion_matrix);
Figure_Confusion_matrix.XLabel = 'Predicted';
Figure_Confusion_matrix.YLabel = 'Actual';

outputArg1 = OF;
outputArg2 = auc;
end

请配合下面用于计算AUC值的函数Measure_AUC一起使用(可将其保存为单独的Measure_AUC.m文件):

function AccumAuc = Measure_AUC(Scores, Labels)
% Measure_AUC  Area under the ROC curve (AUC) for anomaly scores.
%
% Scores: predicted outlier scores; a larger score means "more anomalous"
%         (instances are ranked by descending score).
% Labels: ground-truth labels, one per score, PosLabel = 1 (outlier) &
%         NegLabel = 0 (normal). Any other label is reported and skipped.
%
% Returns the AUC. A run of tied scores is treated as one diagonal ROC
% segment (trapezoid rule), so a tied positive/negative pair contributes 1/2.
% Returns NaN with a warning when either class is absent, because the ROC
% curve is undefined in that case.

Scores = Scores(:);
Labels = Labels(:);
NumInst = length(Scores);
if numel(Labels) ~= NumInst
    error('Measure_AUC:sizeMismatch', ...
        'Scores (%d) and Labels (%d) must have the same number of elements.', ...
        NumInst, numel(Labels));
end

% sort Scores and Labels, most anomalous first
[Scores, index] = sort(Scores, 'descend');
Labels = Labels(index);

PosLabel = 1;
NegLabel = 0;

NumPos = nnz(Labels == PosLabel);
NumNeg = nnz(Labels == NegLabel);

if NumPos == 0 || NumNeg == 0
    warning('Measure_AUC:oneClass', ...
        'AUC is undefined: need at least one positive and one negative label.');
    AccumAuc = NaN;
    return;
end

AccumPos = 0;   % positives ranked so far (TPR = AccumPos * UnitPos)
AccumNeg = 0;   % negatives in the current ROC step (reset after each step)
AccumAuc = 0;

UnitPos = 1 / NumPos;
UnitNeg = 1 / NumNeg;

i = 1;
while i <= NumInst
    temp = AccumPos;   % TPR numerator at the start of this step
    if (i < NumInst) && (Scores(i) == Scores(i + 1))
        % Run of tied scores: the ROC curve moves diagonally across the whole
        % run, so accumulate all of its labels and then add one trapezoid.
        % The guard must be i < NumInst (not NumInst - 1); otherwise a tie
        % involving the last instance is split into separate steps.
        while (i < NumInst) && (Scores(i) == Scores(i + 1))
            if Labels(i) == NegLabel
                AccumNeg = AccumNeg + 1;
            elseif Labels(i) == PosLabel
                AccumPos = AccumPos + 1;
            else
                disp('Label is not defined!');
            end
            i = i + 1;
        end

        % i now points at the last member of the tied run.
        if Labels(i) == NegLabel
            AccumNeg = AccumNeg + 1;
        elseif Labels(i) == PosLabel
            AccumPos = AccumPos + 1;
        else
            disp('Label is not defined!');
        end

        AccumAuc = AccumAuc + (AccumPos + temp) * UnitPos * AccumNeg * UnitNeg / 2;
        AccumNeg = 0;
    else
        % Untied score: a negative adds a rectangle of width UnitNeg at the
        % current TPR; a positive only raises the TPR.
        if Labels(i) == NegLabel
            AccumNeg = AccumNeg + 1;
            AccumAuc = AccumAuc + AccumPos * UnitPos * AccumNeg * UnitNeg;
            AccumNeg = 0;
        elseif Labels(i) == PosLabel
            AccumPos = AccumPos + 1;
        else
            disp('Label is not defined');
        end
    end
    i = i + 1;
end

猜你喜欢

转载自blog.csdn.net/jinhualun911/article/details/112306684