原理:略。
步骤:
二分类问题:
(1)将第一类样本作为正样本,第二类样本作为负样本。首先,对样本的向量空间进行增广,即对n维向量x的首部或者尾部增加一个参数1,增广为(n+1)维向量,并对其进行规范化,即正样本不做处理,负样本的(n+1)维向量取负。
(2)定义一个(n+1)维权向量w,并进行初始化,定义学习步长LearnRate。
(3)进行迭代,对于每个样本,如果w与x的转置的乘积大于0,则不做处理,否则更新权向量:
w=w+LearnRate*x
直到对所有样本的w与x的转置的乘积大于0,退出迭代。
(4)最终得到的w即最终的权向量,得到直线w1+w2*x1+w3*x3+....=0(增广的参数1在(n+1)维向量首部时)。
多分类问题:
(1)同样对所有样本进行增广,但不用进行规范化。
(2)定义k个(n+1)维权向量,k为类别数,并进行初始化,定义学习步长LearnRate。
(3)迭代,如果第i类样本j存在wi*xj'<=wt*xj',其中t为非i类,则进行如下操作:
wi=wi+LearnRate*xj
wt=wt-LearnRate*xj
直到所有wi*xj'>wt*xj',退出迭代。
(4)得到k组权向量,wi-wk=0为第i类和k类样本的分界线。
二维多分类问题代码:
clear clc n=6;%样本点个数 class=4;%类别数 pattern=[1 2 -1 0 -1 2;1 1 -1 -1 1 -1]; %每一列为一个样本 Class=[1 1 2 2 3 4]; %类别 LearnRate=0.2; PlotPats(pattern,Class-1); %绘制样本点 input=[ones(1,n);pattern]'; %每一行为一个规范化后的样本 w=zeros(class,3); for i=1:30 %迭代次数 break_time=0; %当值为0的时候表明本次迭代没有更新,退出迭代 for j=1:n j_class=Class(j); answer=zeros(1,class); for k=1:class %对于每一个样本对于所有类计算w*x' answer(k)=w(k,:)*(input(j,:))'; end ret=0; for l=1:class %如果某j_class类样本的answer小于该样本与其他类的anwser ret记为1 if(l~=j_class)&&(answer(j_class)<=answer(l)) ret=1; end end if ret==1 break_time=break_time+1; for m=1:class%同时改变所有权值 if m~=j_class w(m,:)=w(m,:)-LearnRate*input(j,:); else w(m,:)=w(m,:)+LearnRate*input(j,:); end end end end if break_time==0 break end end newclass=class*(class-1)/2; new=zeros(newclass,3); t=0; for i=1:class%计算每个分界线 for j=i+1:class t=t+1; new(t,:)=w(i,:)-w(j,:) end end for q=1:newclass PlotBoundary([new(q,:)] ,i,1)%绘制分界线 end drawnow
其中绘制样本点与绘制分界线的函数PlotPats.m与PlotBoundary.m如下:
function PlotPats(P,D) % PLOTPATS Plots the training patterns defined by Patterns and Desired. % % P - NELTS x NPATS matrix of input patterns (column vectors). % The first two values in each pattern are used % as the coordinates of the point to be plotted. % % D - NUNITS x NPATS matrix of desired binary output patterns. % The first 2 bits of the output pattern determine the % class of the point: o, +, *, or x. [NELTS,NPATS] = size(P); NUNITS = size(D,1); if NUNITS<2, D=[D;zeros(1,NPATS)]; end colordef none clf reset, whitebg(gcf,[0.82 0.82 0.82]) hold on, box on % Calculate the bounds for the plot and cause axes to be drawn. xmin = min(P(1,:)); xmax = max(P(1,:)); xb = (xmax-xmin)*0.2; ymin = min(P(2,:)); ymax = max(P(2,:)); yb = (ymax-ymin)*0.2; axis([xmin-xb, xmax+xb, ymin-yb, ymax+yb]); title('Input Classification'); xlabel('x1'); ylabel('x2'); class = 1 + D(1,:) + 2*D(2,:); colors = [1 0 1; 1 1 0; 0 1 1; 0 1 0]; symbols = 'o+*x'; for i=1:NPATS c = class(i); plot(P(1,i),P(2,i),symbols(c),'Color',colors(c,:),'LineWidth',3); end
function PlotBoundary(W,iter,done) colors = jet; if ~done lstyle = '--'; color = colors(1+rem(3*iter+9,size(colors,1)),:); else lstyle = '-'; color = [1 1 1]; end d = W(3); if abs(d) < 0.001, d = 0.001; end plot([-2 2],(-W(2)*[-2 2]-W(1))/d,'LineStyle',lstyle,'Color',color,'LineWidth',2); drawnow
运行结果:
三维二分类问题代码:
pattern=[5 2 3 12 30 14;7 3 4 10 12 18;8 5 6 10 36 14]; Desired=[0 0 0 1 1 1]; PlotPats3D(pattern,Desired); [m n]=size(Desired); w = [0 0 0 0]; input=[ones(1,n);pattern]'; for i=1:n if Desired(i)==1 input(i,:)=-input(i,:); end end learnrate=0.8; for i=1:50 error=0; for i=1:n if w*input(i,:)'<=0 error=error+1; w=w+learnrate*input(i,:); end end if error==0 break end end X=-50:0.5:50; Y=-50:0.5:50; [X Y]=meshgrid(X,Y); Z=-(w(1)+w(2)*X+w(3)*Y)/w(4); surf(X,Y,Z);
其中绘制三维样本点函数PlotPats3D.m与make3views.m如下:
function PlotPats3D(P,D) colordef none, clf reset make3view maxx=max(P')+10 minx=min(P')-10 axis([ minx(1) maxx(1) minx(2) maxx(2) minx(3) maxx(3)]) view(72,24) [m,n]=size(P); for i=1:n if D(i)==1 plot3(P(1,i),P(2,i),P(3,i),'y+') elseif D(i)==2 plot3(P(1,i),P(2,i),P(3,i),'mo') end end
function make3view cla axis([-1 1 -1 1 -1 1]), grid on, box on, hold on xlabel('x1'), ylabel('x2'), zlabel('x3') set(gca,'CameraViewAngleMode','manual') rotate3d on colormap jet caxis([-1 1])
运行结果如下:
三维多分类问题代码:
n=6; class=4; pattern=[1 2 7 19 20 26;2 1 10 22 19 26;3 2 5 16 17 18]; Class=[1 1 2 3 4 4]; PlotPats3Dmult(pattern,Class); LearnRate=0.5; input=[ones(1,n);pattern]'; w=zeros(class,4); for i=1:class w(i,:)=[1 -1 -1 -1]; end for i=1:3000 break_time=0; for j=1:n j_class=Class(j); answer=zeros(1,class); for k=1:class answer(k)=w(k,:)*(input(j,:))'; end ret=0; for l=1:class if(l~=j_class)&&(answer(j_class)<=answer(l)) ret=1; end end if ret==1 break_time=break_time+1; for m=1:class if m~=j_class w(m,:)=w(m,:)-LearnRate*input(j,:); else w(m,:)=w(m,:)+LearnRate*input(j,:); end end end end if break_time==0 break end end newclass=class*(class-1)/2; new=zeros(newclass,4); t=0; for i=1:class for j=i+1:class t=t+1; new(t,:)=w(i,:)-w(j,:) end end for i=1:newclass X=-50:0.5:50; Y=-50:0.5:50; [X Y]=meshgrid(X,Y); Z=-(new(i,1)+new(i,2)*X+new(i,3)*Y)/new(i,4); surf(X,Y,Z); end
其中绘制样本点函数如下:
function PlotPats3Dmult(P,D) colordef none, clf reset make3view maxplot=max(P'); minplot=min(P'); axis([minplot(1)-10 maxplot(1)+10 minplot(2)-10 maxplot(2)+10 minplot(3)-10 maxplot(3)+10]) view(72,24) [m n]=size(P); for i=1:n if D(i)==1 plot3(P(1,i),P(2,i),P(3,i),'y+') elseif D(i)==2 plot3(P(1,i),P(2,i),P(3,i),'mo') elseif D(i)==3 plot3(P(1,i),P(2,i),P(3,i),'rx') else plot3(P(1,i),P(2,i),P(3,i),'bo') end end
最后运行结果如下: