上一节从理论上推导了LDA(线性判别分析),下面给出相应的Python实现代码。
""" 线性判别分析步骤: 1.把来自两类w1,w2的训练样本集X分成X1与X2 2.计算各类样本的均值向量m1,m2 3.计算样本类内散度矩阵s1,s2 4.计算总的类内散度矩阵 Sw=S1+S2 5.计算Sw的逆矩阵Sw2 6.求解权向量w=Sw2(m1-m2) 7.计算 g(x)=WT(x-1/2(m1+m2)) 根据是否大于0,判断分类 """ #coding=utf-8 from numpy import * import matplotlib.pyplot as plt ###导入数据 def createDataSet(): #group1=[[0.697,0.460],[0.774,0.376],[0.634,0.264],[0.608,0.318],[0.556,0.215],[0.403,0.211],[0.481,0.149],[0.437,0.211]]#好瓜 #group2=[[0.666,0.091],[0.243,0.267],[0.245,0.057],[0.343,0.099],[0.639,0.161],[0.657,0.198],[0.360,0.370],[0.593,0.042]]#烂瓜 group1=[];group2=[] i=0 while i<10: group1.append(random.random(2)*3)#random.random(i) : 生成具有i个元素的[],每个元素取值为0到1 group2.append((random.random(2)*20)) i+=1 group1=mat(group1) group2=mat(group2) #print(group1) # print(group2) return group1,group2 ###计算样本均值 def compute_mean(group): x1=0;x2=0 for a in group: x1+=array(a)[0][0] x2+=array(a)[0][1] x1=x1/len(group) x2=x2/len(group) #m=np.mean(group,0) np.mean(group,0) # 压缩行,对各列求均值,返回一个1*n矩阵,若为1,则对各行求均值,返回m*1矩阵,若为2,则对所有数求一个均值,返回一个数 # print('平均值:',x1,x2) return mat([x1,x2]) ###计算样本类内散度矩阵 def compute_scatter(group,mean): m,d=shape(group)#m为样本个数,d为样本维度 #将所有样本向量-均值向量 group_mean=group-mean#虽然长度不匹配,但是维度匹配就可以计算 #初始化散度矩阵 s_in=mat([[0,0],[0,0]]) for i in range(m): x=mat(array(group_mean)[i]) # print('x=:',x) s_in=s_in+dot(x.T,x) ###X.T 获得矩阵X的转置,dot为矩阵乘法运算 #print('s_in',s_in) return s_in group1,group2=createDataSet() mean1=compute_mean(group1) mean2=compute_mean(group2) s_in1=compute_scatter(group1,mean1) s_in2=compute_scatter(group2,mean2) #求类内总散度 s_sum=s_in1+s_in2 print("类内散度矩阵:",s_sum) #求类内散度矩阵的逆矩阵 s_rev=s_sum.I ### X.I 返回矩阵X的逆矩阵 print('s_rev:',s_rev) #求权向量W meanW=(mean1-mean2).T print("权向量W:",meanW) w=dot(s_rev,meanW) print('w:',w) print("----------------------") # for a in group1: # distance=dot(w.T,a.T) # mean3=0.5*(mean1+mean2) # meanD=dot(w.T,mean3.T) # print(distance-meanD) # print("-----------------") # for a in group2: # distance=dot(w.T,a.T) # mean3=0.5*(mean1+mean2) # meanD=dot(w.T,mean3.T) # 
print(distance-meanD) #判断测试集是哪一类 xcord1=[];ycord1=[] xcord2=[];ycord2=[] #这两个用来放真实的label分组 xcord3=[];ycord3=[] xcord4=[];ycord4=[] w2=array(w) #通过LDA实现的分类图 for a in group1: item=array(a)[0] distance=dot(w.T,a.T) mean3=0.5*(mean1+mean2) meanDistance=dot(w.T,mean3.T) if((distance-meanDistance)>0): xcord1.append(item[0]) ycord1.append(item[1]) else: xcord2.append(item[0]) ycord2.append(item[1]) for a in group2: item=array(a)[0] distance=dot(w.T,a.T) mean3=0.5*(mean1+mean2) meanDistance=dot(w.T,mean3.T) if((distance-meanDistance)>0): xcord1.append(item[0]) ycord1.append(item[1]) else: xcord2.append(item[0]) ycord2.append(item[1]) #画出LDA的分布图 fig=plt.figure() ax=fig.add_subplot(111) ax.set_title("LDA") plt.xlabel('X') plt.ylabel('Y') ax.scatter(xcord1,ycord1,s=30,c='red',marker='s') ax.scatter(xcord2,ycord2,s=30,c='blue') #画出真实分类图 for a in group1: item=array(a)[0] xcord3.append(item[0]) ycord3.append(item[1]) for a in group2: item=array(a)[0] xcord4.append(item[0]) ycord4.append(item[1]) fig=plt.figure(num=3) ax=fig.add_subplot(111) ax.set_title("realLabel") plt.xlabel('X') plt.ylabel('Y') ax.scatter(xcord3,ycord3,s=30,c='red',marker='s') ax.scatter(xcord4,ycord4,s=30,c='blue') plt.show()