基于KNN的离群点检测(outlier detection)是一种简单高效的算法,其核心思想是:以对象到其k个最近邻的距离之和作为离群因子OF,OF值越大,该对象越有可能是离群点。
基于KNN的离群点检测的MATLAB实现如下:
function [outputArg1,outputArg2] = KNN(inputArg1,inputArg2)
%KNN  KNN-based outlier detection, evaluated against ground-truth labels.
%   The outlier factor OF of each object (row of the data matrix) is the
%   sum of its distances to its k nearest neighbours; the larger OF, the
%   more likely the object is an outlier. The Abnormal_number objects with
%   the largest OF are flagged as outliers. The function prints the AUC,
%   accuracy (ACC), detection rate/recall (DR), precision (P) and false
%   alarm rate (FAR), and draws the confusion matrix as a heatmap.
%
%   [auc, OF] = KNN() reads the data from 'Normalization_wbc.txt' and the
%   labels from 'Label_wbc.txt' (via LOAD).
%   [auc, OF] = KNN(dataFile, labelFile) reads from the given files
%   instead; an omitted, empty or non-text argument keeps the default.
%
%   outputArg1 : AUC of OF against the labels, computed by Measure_AUC.
%   outputArg2 : column vector of outlier factors, one per data row.
%
%   Labels are compared with 1 (abnormal) and 0 (normal); rows with any
%   other label value are counted in neither class.
%   Requires Measure_AUC on the MATLAB path.

DataFile = 'Normalization_wbc.txt';
LabelFile = 'Label_wbc.txt';
if nargin >= 1 && (ischar(inputArg1) || isstring(inputArg1)) && ~isempty(char(inputArg1))
    DataFile = char(inputArg1);
end
if nargin >= 2 && (ischar(inputArg2) || isstring(inputArg2)) && ~isempty(char(inputArg2))
    LabelFile = char(inputArg2);
end

x = load(DataFile);
Label = load(LabelFile);
Label = Label(:);   % force a column so it lines up with the rows of x
m = size(x, 1);
if numel(Label) ~= m
    error('KNN:sizeMismatch', ...
        'The label file has %d entries but the data has %d rows.', numel(Label), m);
end

k = 10;                 % number of nearest neighbours
Abnormal_number = 20;   % number of objects reported as outliers
% Guard against data sets smaller than the hard-coded parameters.
k = max(min(k, m - 1), 0);
Abnormal_number = min(Abnormal_number, m);

Dist = pdist2(x, x);
SortDist = sort(Dist, 2, 'ascend');
% The smallest entry of each row is the object's zero distance to itself,
% so k+1 columns are summed to cover the k nearest other objects.
Nei_k = SortDist(:, 1:k+1);
OF = sum(Nei_k, 2);

auc = Measure_AUC(OF, Label);
disp(auc)

% Flag the Abnormal_number objects with the largest OF as outliers.
[~, index_number] = sort(OF);
IsPredAbnormal = false(m, 1);
IsPredAbnormal(index_number(m-Abnormal_number+1:end)) = true;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%% Evaluation metrics: detection rate / accuracy / false alarm rate, etc. %%%%%%%%%%%%%%
IsRealAbnormal = (Label == 1);
IsRealNormal = (Label == 0);
% Positive class = abnormal object, negative class = normal object.
TP = sum(IsPredAbnormal & IsRealAbnormal);    % abnormal, flagged as abnormal
FN = sum(~IsPredAbnormal & IsRealAbnormal);   % abnormal, missed
FP = sum(IsPredAbnormal & IsRealNormal);      % normal, falsely flagged
TN = sum(~IsPredAbnormal & IsRealNormal);     % normal, kept as normal

% Accuracy
ACC=(TP+TN)/(TP+TN+FP+FN);
fprintf('准确率ACC= %8.5f\n',ACC*100)
% Detection rate == recall R
DR=TP/(TP+FN);
fprintf('检测率DR= %8.5f\n',DR*100)
% Precision P
P=TP/(TP+FP);
fprintf('查准率P= %8.5f\n',P*100)
% False alarm rate
FAR=FP/(TN+FP);
fprintf('误报率FAR= %8.5f\n',FAR*100)

% Confusion matrix: rows = actual class, columns = predicted class.
Confusion_matrix = [TP, FN; FP, TN];
heatmap({'Predicted abnormal', 'Predicted normal'}, ...
    {'Actual abnormal', 'Actual normal'}, Confusion_matrix);

outputArg1 = auc;
outputArg2 = OF;
end
上面的KNN函数需要配合下面计算AUC值的函数Measure_AUC一起使用(建议将其保存为单独的Measure_AUC.m文件):
function AccumAuc = Measure_AUC(Scores, Labels)
% MEASURE_AUC  Area under the ROC curve for anomaly scores.
%
%   AccumAuc = Measure_AUC(Scores, Labels) returns the fraction of
%   (positive, negative) pairs in which the positive instance has the
%   higher score; a pair with equal scores counts as 1/2.
%
%   Scores : vector of predicted scores (higher = more likely positive).
%   Labels : vector of ground-truth labels, same length as Scores;
%            PosLabel = 1 & NegLabel = 0. Any other label value is ignored
%            and a warning is issued.
%
%   Returns NaN, with a warning, when Labels contains no positives or no
%   negatives, since the AUC is undefined in that case.
Scores = Scores(:);
Labels = Labels(:);
PosLabel = 1;
NegLabel = 0;
% Sort by decreasing score so that, when a group of tied scores is reached,
% every positive already seen has a strictly higher score.
[Scores, index] = sort(Scores, 'descend');
Labels = Labels(index);
IsPos = (Labels == PosLabel);
IsNeg = (Labels == NegLabel);
if any(~IsPos & ~IsNeg)
    warning('Measure_AUC:undefinedLabel', 'Label is not defined!');
end
NumPos = sum(IsPos);
NumNeg = sum(IsNeg);
if NumPos == 0 || NumNeg == 0
    warning('Measure_AUC:singleClass', ...
        'AUC is undefined when Labels contain no positives or no negatives.');
    AccumAuc = NaN;
    return;
end
NumInst = numel(Scores);
AccumPos = 0;   % positives with a score strictly above the current group
PairSum = 0;    % correctly ordered pos/neg pairs + 0.5 * tied pairs
i = 1;
while i <= NumInst
    % Extend [i, j] over the whole run of scores equal to Scores(i),
    % including a run that ends at the last instance.
    j = i;
    while j < NumInst && Scores(j + 1) == Scores(i)
        j = j + 1;
    end
    GroupPos = sum(IsPos(i:j));
    GroupNeg = sum(IsNeg(i:j));
    % Each negative in the group is ranked below the AccumPos positives
    % seen so far and ties with the GroupPos positives inside the group.
    PairSum = PairSum + GroupNeg * (AccumPos + GroupPos / 2);
    AccumPos = AccumPos + GroupPos;
    i = j + 1;
end
AccumAuc = PairSum / (NumPos * NumNeg);