fisher线性判别

Fisher线性判别(Fisher Linear Discrimination,FLD),也称线性判别式分析(Linear Discriminant Analysis, LDA)。FLD是基于样本类别进行整体特征提取的有效方法。它在使用PCA方法进行降维的基础上考虑到训练样本的类间信息。FLD的基本原理就是找到一个最合适的投影轴,使各类样本在该轴上投影之间的距离尽可能远,而每一类内的样本的投影尽可能紧凑,从而使分类效果达到最佳,即在最大化类间距离的同时最小化类内距离。FLD方法在进行图像整体特征提取方面有着广泛的应用。

在应用统计方法解决模式识别问题时,经常会遇到所谓的“维数灾难”的问题,在低维空间里适用的方法在高维空间里可能完全不适用。因此压缩特征空间的维数有时是很重要的。Fisher方法实际上涉及维数压缩的问题。如果把多维特征空间的点投影到一条直线上,就能把特征空间压缩成一维,这个在数学上是很容易办到的。但是,在高维空间里很容易分开的样品,把它们投影到任意一根直线上,有可能不同类别的样品就混在一起,无法区分,如图1(a)所示投影到xl或x2轴无法区分。若把直线绕原点转动一下,就有可能找到一个方向,样品投影到这个方向的直线上,各类样品就能很好地分开,如图1(b)所示。因此直线方向的选择很重要。一般地,总能够找到一个最好的方向,使样品投影到这个方向的直线上很容易分开。如何找到这个最好的直线方向以及如何实现向最好方向投影的变换,这正是Fisher算法要解决的基本问题,这个投影变换恰是我们所寻求的解向量w*。


图1 Fisher线性判别示意图

样品训练集以及待测样品的特征总数目为n。为了找到最佳投影方向,需要计算出各类样品均值,样品类内离散度矩阵Si和总类间离散度矩阵Sw,样品类间离散度矩阵Sb,根据Fisher准则,找到最佳投影向量,将训练集内所有样品进行投影,投影到一维Y空间,由于Y空间是一维的,则需要求出Y空间的划分边界点,找到边界点后,就可以对待测样品进行一维Y空间的投影,判断它的投影点与分界点的关系,将其归类。具体方法如下。



[cpp]  view plain  copy
  1. /****************************************************************** 
  2. *   函数名称:Fisher_2Classes(int Class0, int Class1) 
  3. *   函数类型:int  
  4. *   参数说明:Class0,Class1:0~9中的任意两个类别 
  5. *   函数功能:两类Fisher分类器,返回Class0,Class1中的一个 
  6. ******************************************************************/  
  7. int Classification::Fisher_2Classes(int Class0, int Class1)  
  8. {  
  9.     double Xmeans[2][25];//两类的均值  
  10.     double S[2][25][25];//样品类内离散度矩阵  
  11.     double Sw[25][25];//总类间离散度矩阵  
  12.     double Sw_[25][25];//Sw的逆矩阵  
  13.     double W[25];//解向量w*  
  14.     double difXmeans[25];//均值差  
  15.     double X[25];//未知样品  
  16.     double m0,m1;//类样品均值  
  17.     double y0;//阈值y0  
  18.     int i,j,k;  
  19.   
  20.     for(i=0;i<2;i++)  
  21.         for(j=0;j<25;j++)  
  22.             Xmeans[i][j]=0;  
  23.     int num0,num1;      //两类样品的个数  
  24.     //两类样品特征  
  25.     double mode0[200][25],mode1[200][25];  
  26.     //两类样品个数  
  27.     num0=40;//pattern[Class0].number;  
  28.     num1=40;//pattern[Class1].number;  
  29.     for(i=0;i<num0;i++)  
  30.     {  
  31.         for(j=0;j<25;j++)  
  32.         {  
  33.             Xmeans[0][j]+=pattern[Class0].feature[i][j];  
  34.             mode0[i][j]=pattern[Class0].feature[i][j];  
  35.         }  
  36.     }  
  37.   
  38.     for(i=0;i<num1;i++)  
  39.     {  
  40.         for(j=0;j<25;j++)  
  41.         {  
  42.             Xmeans[1][j]+=pattern[Class1].feature[i][j];      
  43.             mode1[i][j]=pattern[Class1].feature[i][j];  
  44.         }  
  45.     }  
  46.     //求得两个样品均值向量  
  47.     for(i=0;i<25;i++)      
  48.     {  
  49.         Xmeans[0][i]/=(double)num0;  
  50.         Xmeans[1][i]/=(double)num1;  
  51.     }  
  52.     //求两类样品类内离散度矩阵  
  53.     for(i=0;i<25;i++)  
  54.     for(j=0;j<25;j++)  
  55.     {  
  56.         double s0=0.0,s1=0.0;  
  57.         for(k=0;k<num0;k++)  
  58.             s0=s0+(mode0[k][i]-Xmeans[0][i])*(mode0[k][j]-Xmeans[0][j]);  
  59.         s0=s0/(double)(num0-1);  
  60.         S[0][i][j]=s0;//第一类  
  61.         for(k=0;k<num1;k++)  
  62.             s1=s1+(mode1[k][i]-Xmeans[1][i])*(mode1[k][j]-Xmeans[1][j]);  
  63.         s1=s1/(double)(num1-1);  
  64.         S[1][i][j]=s1;//第二类       
  65.     }  
  66.     //总类间离散度矩阵  
  67.     for(i=0;i<25;i++)  
  68.     for(j=0;j<25;j++)  
  69.     {  
  70.         Sw[i][j]=S[0][i][j]+S[1][i][j];  
  71.     }  
  72.     //Sw的逆矩阵  
  73.     for(i=0;i<25;i++)  
  74.         for(j=0;j<25;j++)  
  75.             Sw_[i][j]=Sw[i][j];   
  76.     double(*p)[25]=Sw_;   
  77.     brinv(*p,25);       //Sw的逆矩阵Sw_  
  78.     //计算w*  w*=Sw_×(Xmeans0-Xmeans1)  
  79.     for(i=0;i<25;i++)  
  80.         difXmeans[i]=Xmeans[0][i]-Xmeans[1][i];  
  81.     for(i=0;i<25;i++)  
  82.         W[i]=0.0;  
  83.     brmul(Sw_,difXmeans,25,W);//计算出W*  
  84.       
  85.     //各类样品均值  
  86.     m0=0.0;  
  87.     m1=0.0;  
  88.     for(i=0;i<num0;i++)  
  89.     {  
  90.         m0+=brmul(W,mode0[i],25);  
  91.     }  
  92.     for(i=0;i<num1;i++)  
  93.     {  
  94.         m1+=brmul(W,mode1[i],25);  
  95.     }  
  96.     m0/=(double)num0;  
  97.     m1/=(double)num1;  
  98.     y0=(num0*m0+num1*m1)/(num0+num1);//阈值y0  
  99.       
  100.     //对于任意的手写数字X  
  101.     for(i=0;i<25;i++)  
  102.         X[i]=testsample[i];  
  103.     double y;//X在w*上的投影点  
  104.     y=brmul(W,X,25);  
  105.     if (y>=y0)   
  106.         return Class0;  
  107.     else  
  108.         return Class1;  
  109. }  
  110.   
  111. /****************************************************************** 
  112. *   函数名称:Fisher() 
  113. *   函数类型:int  
  114. *   函数功能:Fisher分类器,返回手写数字的类别 
  115. ******************************************************************/  
  116. int Classification::Fisher()  
  117. {  
  118.     int i,j,number,maxval,num[10];  
  119.     for(i=0;i<10;i++)  
  120.         num[i]=0;  
  121.     for(i=0;i<10;i++)  
  122.         for(j=0;j<i;j++)  
  123.             num[Fisher_2Classes(i,j)]++;  
  124.     maxval=num[0];  
  125.     number=0;  
  126.     for(i=1;i<10;i++)  
  127.     {  
  128.         if(num[i]>maxval)  
  129.         {  
  130.             maxval=num[i];  
  131.             number=i;  
  132.         }  
  133.     }  
  134.     return number;  
  135. }  

[cpp]  view plain  copy
  1. /****************************************************************** 
  2. *函数名称:brmul(double a[],double b[][25],int n,double c[]) 
  3. *函数类型:void 
  4. *参数说明:a-双精度实型数组,存放A的元素。 
  5. *          b-双精度实型数组,存放B的元素。 
  6. *          n-整型变量,矩阵A的列数,也是矩阵B的行数。 
  7. *          c-双精度实型数组,存放乘积矩阵C=AB的元素。 
  8. *函数功能:求矩阵A与B的乘积矩阵C=AB。 
  9. ******************************************************************/  
  10. void brmul(double a[],double b[][25],int n,double c[])//矩阵乘法,c=a*b  
  11. {   
  12.     for(int i=0;i<n;i++)  
  13.     {  
  14.         for(int j=0;j<n;j++)  
  15.             c[i]+=a[j]*b[j][i];  
  16.     }  
  17.     return;  
  18. }  

[cpp]  view plain  copy
  1. /****************************************************************** 
  2. *函数名称:brinv(double a[],int n) 
  3. *函数类型:void 
  4. *参数说明:a--双精度实型数组,n--整型变量,方阵A的阶数 
  5. *函数功能:用全选主元Gauss-Jordan消去法求n阶实矩阵A的逆矩阵 
  6. ******************************************************************/  
  7. void brinv(double a[],int n)  
  8. {   
  9.     int *is,*js,i,j,k,l,u,v;  
  10.     double d,p;  
  11.     is=new int[n];  
  12.     js=new int[n];  
  13.     for (k=0; k<=n-1; k++)  
  14.     {   
  15.         d=0.0;  
  16.         for (i=k; i<=n-1; i++)  
  17.             for (j=k; j<=n-1; j++)  
  18.             {   
  19.                 l=i*n+j; p=fabs(a[l]);  
  20.                 if (p>d)   
  21.                 {   
  22.                     d=p; is[k]=i; js[k]=j;  
  23.                 }  
  24.             }  
  25.             if (d+1.0==1.0)  
  26.             {   
  27.                 free(is); free(js); printf("err**not inv\n");  
  28.                 return;  
  29.             }  
  30.             if (is[k]!=k)  
  31.                 for (j=0; j<=n-1; j++)  
  32.                 {   
  33.                     u=k*n+j; v=is[k]*n+j;  
  34.                     p=a[u]; a[u]=a[v]; a[v]=p;  
  35.                 }  
  36.                 if (js[k]!=k)  
  37.                     for (i=0; i<=n-1; i++)  
  38.                     {   
  39.                         u=i*n+k; v=i*n+js[k];  
  40.                         p=a[u]; a[u]=a[v]; a[v]=p;  
  41.                     }  
  42.                     l=k*n+k;  
  43.                     a[l]=1.0/a[l];  
  44.                     for (j=0; j<=n-1; j++)  
  45.                         if (j!=k)  
  46.                         {  
  47.                             u=k*n+j; a[u]=a[u]*a[l];  
  48.                         }  
  49.                         for (i=0; i<=n-1; i++)  
  50.                             if (i!=k)  
  51.                                 for (j=0; j<=n-1; j++)  
  52.                                     if (j!=k)  
  53.                                     {   
  54.                                         u=i*n+j;  
  55.                                         a[u]=a[u]-a[i*n+k]*a[k*n+j];  
  56.                                     }  
  57.                                     for (i=0; i<=n-1; i++)  
  58.                                         if (i!=k)  
  59.                                         {  
  60.                                             u=i*n+k; a[u]=-a[u]*a[l];  
  61.                                         }  
  62.     }  
  63.     for (k=n-1; k>=0; k--)  
  64.     {   
  65.         if (js[k]!=k)  
  66.             for (j=0; j<=n-1; j++)  
  67.             {   
  68.                 u=k*n+j; v=js[k]*n+j;  
  69.                 p=a[u]; a[u]=a[v]; a[v]=p;  
  70.             }  
  71.             if (is[k]!=k)  
  72.                 for (i=0; i<=n-1; i++)  
  73.                 {   
  74.                     u=i*n+k; v=i*n+is[k];  
  75.                     p=a[u]; a[u]=a[v]; a[v]=p;  
  76.                 }  
  77.     }  
  78.     delete is;   
  79.     delete js;  
  80. }  

猜你喜欢

转载自blog.csdn.net/shnu_pfh/article/details/78671002