% 概率线性判别分析程序 (Probabilistic Linear Discriminant Analysis, PLDA)

% 自己写了一个PLDA程序,供大家交流讨论

function [Sigma,mu,F,G,H]=My_PLDA2(Xtrain,Fdim,Gdim,maxIter,tol,showPlot)
% MY_PLDA2  Fit x_ij = mu + F*h_i + G*w_ij + eps_ij (PLDA) with EM.
%
%   [Sigma,mu,F,G,H] = My_PLDA2(Xtrain,Fdim,Gdim,maxIter,tol)
%   [...]            = My_PLDA2(..., showPlot)
%
%   Inputs
%     Xtrain   - cell array; Xtrain{i} holds the samples of class i as
%                rows (J-by-M). Every class must have the same number of
%                rows J.
%     Fdim     - number of columns of F (latent dimension shared per class).
%     Gdim     - number of columns of G (latent dimension per sample).
%     maxIter  - iteration budget; iteration 1 is the initial state, so at
%                most maxIter-1 EM updates are performed.
%     tol      - stop when the relative change of the objective LL is < tol.
%     showPlot - optional (default true): plot LL(5:end) after every
%                iteration that does not stop the loop.
%
%   Outputs
%     Sigma - M-by-1 diagonal of the residual noise covariance.
%     mu    - M-by-1 mean of the normalized training data.
%     F, G  - M-by-Fdim and M-by-Gdim loading matrices.
%     H     - Fdim-by-I posterior means of the per-class latent h_i.
%
%   Notes
%     * Every class is standardized with the mean/std of Xtrain{1} (not of
%       the pooled data). Those statistics are not returned, so callers
%       must apply the same normalization to new data themselves.
%     * F and G are initialized with randn, so results depend on RNG state.

if nargin < 6 || isempty(showPlot)
    showPlot = true;
end

% ---- Data normalization (statistics of the first class) ----
I = length(Xtrain);
X0 = Xtrain{1};
for i = 1:I
    Xtrain{i} = (Xtrain{i} - ones(size(Xtrain{i},1),1)*mean(X0)) / diag(std(X0));
end

% The data is stored below as an M-by-J-by-I array, so all classes must
% contribute the same number of samples J.
classSizes = cellfun(@(c) size(c,1), Xtrain);
J = classSizes(1);
if any(classSizes(:) ~= J)
    error('My_PLDA2:unequalClassSizes', ...
        'All classes in Xtrain must have the same number of samples.');
end

Xall = vertcat(Xtrain{:});
[N,M] = size(Xall);
mu = mean(Xall)';

% ---- Center every class with the global mean; X_I(:,:,i) is M-by-J ----
X_I = zeros(M,J,I);
S = 0;
for i = 1:I
    X_I(:,:,i) = Xtrain{i}' - repmat(mu,[1,J]);
    S = S + X_I(:,:,i)*X_I(:,:,i)';
end

% ---- Initialization ----
F = randn(M,Fdim);
G = randn(M,Gdim);
Sigma = diag(S/N);

% Posterior statistics of the latent variables
H = zeros(Fdim,I);                 % E[h_i]
H2 = zeros(Fdim,Fdim,I);           % E[h_i h_i']
W = zeros(Gdim,J,I);               % E[w_ij]
W2 = zeros(Gdim,Gdim,J,I);         % E[w_ij w_ij']
HW = zeros(Fdim,Gdim,J,I);         % E[h_i w_ij']
I_G = eye(Gdim);
I_F = eye(Fdim);

% ---- EM iterations ----
% LL(1) = -inf only serves as the "previous value" for iteration 2; the
% relative change computed there is NaN and therefore never stops the loop.
LL = -inf(1,maxIter);
for iter = 2:maxIter
    InvS = diag(1./Sigma);

    % Precision of x_ij given h_i with w_ij integrated out. It does not
    % depend on the class, so it is computed once per iteration.
    Q = inv(diag(Sigma) + G*G');

    % E-step for h_i. The posterior covariance Uh is the same for every
    % class because all classes have J samples.
    Uh = pinv(I_F + J*F'*Q*F);
    for i = 1:I
        H(:,i) = Uh*(F'*Q*sum(X_I(:,:,i),2));
        H2(:,:,i) = H(:,i)*H(:,i)' + Uh;
    end

    % E-step for w_ij and the cross moments E[h_i w_ij']. The covariance
    % terms are class-independent as well.
    a1 = pinv(I_G + G'*InvS*G);
    a2 = G'*InvS;
    cov_hw = -Uh'*F'*InvS*G*a1;
    cov_w = (I_G - cov_hw'*F'*InvS*G)*a1;
    for i = 1:I
        resid = X_I(:,:,i) - repmat(F*H(:,i),[1,J]);
        W(:,:,i) = a1*a2*resid;
        for j = 1:J
            HW(:,:,j,i) = H(:,i)*W(:,j,i)' + cov_hw;
            W2(:,:,j,i) = W(:,j,i)*W(:,j,i)' + cov_w;
        end
    end

    % M-step for F (with G held at its current value)
    Sum1 = 0;
    Sum2 = 0;
    for i = 1:I
        Xi = X_I(:,:,i);
        h_i = repmat(H(:,i),[1,J]);
        Sum1 = Sum1 + Xi*h_i' - G*sum(HW(:,:,:,i),3)';
        Sum2 = Sum2 + J*H2(:,:,i);
    end
    F = Sum1/Sum2;

    % M-step for G (using the F just updated)
    Sum1 = 0;
    Sum2 = 0;
    for i = 1:I
        Xi = X_I(:,:,i);
        Sum1 = Sum1 + Xi*W(:,:,i)' - F*sum(HW(:,:,:,i),3);
        Sum2 = Sum2 + sum(W2(:,:,:,i),3);
    end
    G = Sum1/Sum2;

    % M-step for Sigma. Only the diagonal is kept, so each cross term is
    % doubled instead of adding its transpose.
    Sum1 = 0;
    for i = 1:I
        Xi = X_I(:,:,i);
        h_i = repmat(H(:,i),[1,J]);
        p1 = Xi*Xi' + F*J*H2(:,:,i)*F' + G*sum(W2(:,:,:,i),3)*G';
        p2 = 2*F*sum(HW(:,:,:,i),3)*G' - 2*Xi*h_i'*F' - 2*Xi*W(:,:,i)'*G';
        Sum1 = Sum1 + p1 + p2;
    end
    Sigma = diag(Sum1/N);

    % Convergence objective: sum of Gaussian log-densities of
    % x_ij - F*E[h_i] under N(0, diag(Sigma) + G*G'). This is a plug-in
    % value, not the exact marginal likelihood.
    Sig = diag(Sigma) + G*G';
    invSig = pinv(Sig);
    [R,p] = chol(Sig);
    if p == 0
        logdetSig = 2*sum(log(diag(R)));
    else
        % Sig is not numerically positive definite; fall back to det().
        logdetSig = log(det(Sig));
    end
    L = 0;
    for i = 1:I
        Ri = X_I(:,:,i) - F*repmat(H(:,i),[1,J]);
        L = L - 0.5*J*M*log(2*pi) - 0.5*trace(Ri'*invSig*Ri) - 0.5*J*logdetSig;
    end
    LL(iter) = L;
    cvg = abs(LL(iter)-LL(iter-1))/abs(LL(iter-1));
    if cvg < tol
        break;
    end
    if showPlot
        plot(LL(5:end));
        drawnow;
    end
end

% 转载自 https://www.cnblogs.com/ZJU-missile/p/10498708.html