An Introduction to Variational Inference: Variational Bayesian EM

In the earlier post on EM (http://blog.csdn.net/foreseerwang/article/details/75222522), we walked through using the EM algorithm to fit a Gaussian mixture model (GMM: Gaussian Mixture Model) for clustering, and compared it against the k-means algorithm; the advantages of the GMM were already visible there.


However, the GMM still has some limitations; for instance, the number of clusters K must be known in advance. By placing priors on the GMM parameters, K can be determined automatically. This is where the VBEM algorithm of this post comes in; it is discussed in detail in PRML Section 10.2 and MLaPP Section 21.6.
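Concretely, the Bayesian GMM of PRML Section 10.2 keeps the usual mixture likelihood but places a Dirichlet prior on the mixing weights and a Gauss-Wishart prior on the component means and precisions; these are exactly the distributions the code below refers to by their PRML equation numbers:

$$
\begin{aligned}
p(\mathbf{Z}\mid\boldsymbol{\pi}) &= \prod_{n=1}^{N}\prod_{k=1}^{K}\pi_k^{z_{nk}} && \text{(10.37)}\\
p(\mathbf{X}\mid\mathbf{Z},\boldsymbol{\mu},\boldsymbol{\Lambda}) &= \prod_{n=1}^{N}\prod_{k=1}^{K}\mathcal{N}\big(\mathbf{x}_n\mid\boldsymbol{\mu}_k,\boldsymbol{\Lambda}_k^{-1}\big)^{z_{nk}} && \text{(10.38)}\\
p(\boldsymbol{\pi}) &= \operatorname{Dir}(\boldsymbol{\pi}\mid\alpha_0) && \text{(10.39)}\\
p(\boldsymbol{\mu},\boldsymbol{\Lambda}) &= \prod_{k=1}^{K}\mathcal{N}\big(\boldsymbol{\mu}_k\mid\mathbf{m}_0,(\beta_0\boldsymbol{\Lambda}_k)^{-1}\big)\,\mathcal{W}(\boldsymbol{\Lambda}_k\mid\mathbf{W}_0,\nu_0) && \text{(10.40)}
\end{aligned}
$$

VBEM then alternates between updating the variational posterior over Z and the variational posterior over (pi, mu, Lambda), in direct analogy to the E and M steps of ordinary EM.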


Moreover, in practice we found that under certain initializations, plain k-means and GMM/EM fail to produce accurate clusterings, presumably because of how unbalanced the data is. See the figure below: panel 1 shows data generated from a GMM whose Gaussian components have the same parameters as in the previous post, but whose mixing proportions are drawn at random. In this case the three classes are badly unbalanced, with one class containing very few points.


Panels 2 and 3 show the clustering results of k-means and GMM/EM respectively; both disagree with the original data.


In this situation, VBEM again shows its strength. The three plots in the second row are the VBEM results. As both PRML and MLaPP recommend, the number of components K for VBEM should be chosen larger than the true number of clusters; here K=6 is used. Panel 4 shows a direct 6-way clustering that serves as VBEM's initialization. Panel 5 shows the result after 100 iterations, where 4 clusters still remain. Panel 6 shows the result once the iterations have stabilized. We can see that:

1. three clusters are generated automatically (the pruning mechanism is sketched right after this list);

2. the distribution of the three clusters matches the original data very closely.
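The automatic pruning in point 1 can be read off the Dirichlet posterior over the mixing weights. Since the posterior concentration parameters are $\alpha_k = \alpha_0 + N_k$ (PRML Eq. (10.58), used in the code below), the posterior mean of each mixing weight is

$$
\mathbb{E}[\pi_k] = \frac{\alpha_k}{\sum_{j=1}^{K}\alpha_j} = \frac{\alpha_0 + N_k}{K\alpha_0 + N}.
$$

With $\alpha_0 \ll 1$, any component that attracts almost no responsibility ($N_k \approx 0$) gets an expected weight near zero and is effectively switched off, so only the components genuinely supported by the data survive.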


This is VBEM. The world is becoming clearer and clearer...




The code is as follows:

clear all;
close all;
rng(2);

%% Parameters
N = 1000;                                   % total number of data points
D = 2;                                      % data dimension
K = 3;                                      % number of clusters
Pz = rand([K,1]);                           % random mixing proportions
Pz = Pz/sum(Pz);

% True-parameter initialization, same as in the earlier EM clustering post
mu = [1 2; -6 2; 7 1];
sigma=zeros(K,D,D);
sigma(1,:,:)=[2 -1.5; -1.5 2];
sigma(2,:,:)=[5 -2.; -2. 3];
sigma(3,:,:)=[1 0.1; 0.1 2];

%% Data Generation and display
x = zeros(N,D);
PzCDF1 = 0;
figure(1); subplot(2,3,1); hold on;
figure(2); hold on;
for ii = 1:K,
    PzCDF2 = PzCDF1 + Pz(ii);
    PzIdx1 = round(PzCDF1*N);
    PzIdx2 = round(PzCDF2*N);
    x(PzIdx1+1:PzIdx2,:) = mvnrnd(mu(ii,:), squeeze(sigma(ii,:,:)), PzIdx2-PzIdx1);
    PzCDF1 = PzCDF2;
    
    figure(1); subplot(2,3,1); hold on;
    plot(x(PzIdx1+1:PzIdx2,1),x(PzIdx1+1:PzIdx2,2),'o');
end;
[~, tmpidx] = sort(rand(N,1));
x = x(tmpidx,:);                        % shuffle data

figure(1); subplot(2,3,1);
plot(mu(:,1),mu(:,2),'k*');
axis([-10,10,-4,8]);
title('1.Generated Data (original)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');

figure(2);
plot(x(:,1),x(:,2),'o');
figure(2);
plot(mu(:,1),mu(:,2),'k*');
axis([-10,10,-4,8]);
title('Generated Data (original)');
xlabel('x1');
ylabel('x2');

fprintf('\n$$ Data generation and display completed...\n');

%% clustering: Matlab k-means
k_idx=kmeans(x,K);                  % use Matlab's built-in k-means
figure(1); subplot(2,3,2); hold on;
for ii=1:K,
    idx=(k_idx==ii);
    plot(x(idx,1),x(idx,2),'o');
    center = mean(x(idx,:));
    plot(center(1),center(2),'k*');
end;
axis([-10,10,-4,8]);
title('2.Clustering: Matlab k-means', 'fontsize', 20);
xlabel('x1');
ylabel('x2');

fprintf('\n$$ K-means clustering completed...\n');

%% clustering: EM
% Refer to pp.351, MLaPP
% Pw: weight
% mu: mean of the Gaussian distribution
% sigma: covariance matrix of the Gaussian distribution
% r(i,k): responsibility; rk: sum of r over i
% px: p(x|mu,sigma)
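% In update-rule form, the loop below implements:
%   E step: r(i,k) proportional to Pw(k)*N(x_i|mu_k,sigma_k), normalized over k
%   M step: rk = sum_i r(i,k);  Pw(k) = rk/N
%           mu_k = sum_i r(i,k)*x_i / rk
%           sigma_k = sum_i r(i,k)*x_i*x_i' / rk - mu_k*mu_k'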

% the clustering result above serves as the initial value for EM
Pw=zeros(K,1);
for ii=1:K,
    idx=(k_idx==ii);
    Pw(ii)=sum(idx)*1.0/N;
    mu(ii,:)=mean(x(idx,:));
    sigma(ii,:,:)=cov(x(idx,1),x(idx,2));
end;

px=zeros(N,K);
for jj=1:100, % for simplicity, loop a fixed number of times without a convergence check
    for ii=1:K,
        px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sigma(ii,:,:)));
        % Matlab's built-in mvnpdf sometimes errors out claiming sigma is not positive definite
    end;
    
    % E step
    temp=px.*repmat(Pw',N,1);
    r=temp./repmat(sum(temp,2),1,K);

    % M step
    rk=sum(r);
    Pw=rk'/N;
    mu=r'*x./repmat(rk',1,D);
    for ii=1:K
        sigma(ii,:,:)=x'*(repmat(r(:,ii),1,D).*x)/rk(ii)-mu(ii,:)'*mu(ii,:);
    end;
end;

% display
[~,clst_idx]=max(r,[],2);   % assign by largest responsibility (r, not px, so the weights Pw are respected)
figure(1); subplot(2,3,3); hold on;
for ii=1:K,
    idx=(clst_idx==ii);
    plot(x(idx,1),x(idx,2),'o');
    center = mean(x(idx,:));
    plot(center(1),center(2),'k*');
end;

axis([-10,10,-4,8]);
title('3.Clustering: GMM/EM', 'fontsize', 20);
xlabel('x1');
ylabel('x2');

fprintf('\n$$ Gaussian Mixture using EM completed...\n');

%% Variational Bayes EM
% Refer to ch.10.2, PRML
% x: visible variable, N * D
% z: latent variable, N * K

% z: Pz, Ppi, alp0, alpk
%    Pz = P(z|pi);                                          PRML(10.37)
%    Ppi = Dir(pi|alp0)                                     PRML(10.39)
% x: Px, Pz, Ppi, mu, lambda, m0, beta0, W0, nu0
%    Px = P(x|z, mu, lambda);        Gaussian distribution  PRML(10.38)
%    P(mu, lambda) = P(mu|lambda)*P(lambda)                PRML(10.40)
%        = N(mu|m0, (beta0*lambda)^-1) * Wi(lambda|W0, nu0)

% rho: N*K, defined in PRML(10.46)
% r: N*K, responsibility; rho normalized over k, defined in PRML(10.49)
% N_k: sum of r over n                    see PRML(10.51)
% xbar_k:                                 see PRML(10.52)
% S_k                                     see PRML(10.53)
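% Concretely (PRML Eqs. (10.51)-(10.53)), the statistics recomputed in each
% M step below are:
%   N_k    = sum_n r(n,k)
%   xbar_k = (1/N_k) * sum_n r(n,k) * x_n
%   S_k    = (1/N_k) * sum_n r(n,k) * (x_n - xbar_k) * (x_n - xbar_k)'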

K = 6;                  % use more components than needed; VBEM prunes the number automatically
k_idx=kmeans(x,K);      % Matlab's built-in k-means; its result initializes VBEM

figure(1); subplot(2,3,4); hold on;
for ii=1:K,
    idx=(k_idx==ii);
    plot(x(idx,1),x(idx,2),'o');
    center = mean(x(idx,:));
    plot(center(1),center(2),'k*');
    
    mu(ii,:) = mean(x(idx,:));
    sigma(ii,:,:)=cov(x(idx,1),x(idx,2));
    px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sigma(ii,:,:)));
    % Matlab's built-in mvnpdf sometimes complains that sigma is not positive definite, hence the hand-written GaussPDF
end;
axis([-10,10,-4,8]);
title('4.Clustering: VBEM (initial)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');

% Initialization; see PRML Eq. (10.40) for the definitions
alp0 = 0.0001;          % alpha0, should be << 1 so that the number of clusters is pruned automatically
m0 = 0;
beta0 = rand()+0.5;         % ad hoc initialization
W0 = squeeze(mean(sigma));
W0inv = pinv(W0);
nu0 = D*2;                  % ad hoc initialization

S_k = zeros(K,D,D);
W_k = zeros(K,D,D);
E_mu_lmbd = zeros(N,K);     % the left-hand side of PRML Eq. (10.64)

r = px./repmat(sum(px,2),1,K);                  % N*K
N_k = ones(1,K)*(-100);
fprintf('\n');
for ii = 1:1000,
    % M-step
    N_k_new = sum(r);                           % 1*K, PRML Eq. (10.51)
    N_k_new(N_k_new<N/1000.0)=1e-4;             % guard against very small or zero Nk
    if sum(abs(N_k_new-N_k))<0.001,
        break;  % early stop: once Nk has essentially stopped changing, end the iteration
    else
        N_k = N_k_new;
    end;
    
    xbar_k = r'*x./repmat(N_k', 1, D);          % K*D, PRML Eq. (10.52)
    for jj = 1:K,
        dx = x-repmat(xbar_k(jj,:), N, 1);      % N*D
        S_k(jj,:,:) = dx'*(dx.*repmat(r(:,jj),1,D))/N_k(jj); % D*D, PRML Eq. (10.53)
    end;
    
    alp_k = alp0 + N_k;         % PRML Eq. (10.58)
    beta_k = beta0 + N_k;       % PRML Eq. (10.60)
    m_k = (beta0*m0 + repmat(N_k',1,D).*xbar_k)./...
        repmat(beta_k',1,D);    % K*D, PRML Eq. (10.61)
    for jj = 1:K,
        dxm = xbar_k(jj,:)-m0;
        Wkinv = W0inv + N_k(jj)*squeeze(S_k(jj,:,:)) + ...
            dxm'*dxm*beta0*N_k(jj)/(beta0+N_k(jj));
        W_k(jj,:,:) = pinv(Wkinv);           % K*D*D, PRML Eq. (10.62)
    end;
    nu_k = nu0 + N_k;                        % 1*K, PRML Eq. (10.63)
    
    % E-step: update the responsibilities r
    alp_tilde = sum(alp_k);
    E_ln_pi = psi(alp_k) - psi(alp_tilde);      % PRML Eq. (10.66)
    E_ln_lambda = D*log(2)*ones(1,K);
    for jj = 1:D,
        E_ln_lambda = E_ln_lambda + psi((nu_k+1-jj)/2);
    end;
    for jj = 1:K,
        E_ln_lambda(jj) = E_ln_lambda(jj) + ...
            log(det(squeeze(W_k(jj,:,:))));     % PRML Eq. (10.65)
        dxm = x-repmat(m_k(jj,:),N,1);          % N*D
        Dbeta = D/beta_k(jj);
        for nn = 1:N,
            E_mu_lmbd(nn,jj) = Dbeta+nu_k(jj)*(dxm(nn,:)*...
                squeeze(W_k(jj,:,:))*dxm(nn,:)');   % PRML Eq. (10.64)
        end;
    end;
    
    rho = exp(repmat(E_ln_pi,N,1)+repmat(E_ln_lambda,N,1)/2-...
        E_mu_lmbd/2);                           % PRML Eq. (10.46)
    r = rho./repmat(sum(rho,2),1,K);            % PRML Eq. (10.49)
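    % Note (a possible robustness tweak, not in the original code): rho is
    % formed by direct exponentiation and can underflow for points far from
    % all components; subtracting each row's maximum exponent before exp()
    % (the log-sum-exp trick) leaves r unchanged but is numerically safer.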
    
    if mod(ii,10)==0,
        fprintf('%3d loops finished.\n', ii);
    end;
    
    if ii == 100,
        [~,clst_idx]=max(r,[],2);
        figure(1); subplot(2,3,5); hold on;
        
        for kk=1:K,
            idx=(clst_idx==kk);
            if sum(idx)/N>0.01,
                plot(x(idx,1),x(idx,2),'o');
                center = mean(x(idx,:));
                plot(center(1),center(2),'k*');
            end;
        end;
        axis([-10,10,-4,8]);
        title('5.Clustering: VBEM (100 iter)', 'fontsize', 20);
        xlabel('x1');
        ylabel('x2');
    end;
end;
    
[~,clst_idx]=max(r,[],2);
figure(1); subplot(2,3,6); hold on;
Nclst = 0;
for ii=1:K,
    idx=(clst_idx==ii);
    if sum(idx)/N>0.01,
        Nclst = Nclst+1;
        plot(x(idx,1),x(idx,2),'o');
        center = mean(x(idx,:));
        plot(center(1),center(2),'k*');
    end;
end;
fprintf('\n$$ Using VBEM, %d clusters found in total.\n\n', Nclst);
axis([-10,10,-4,8]);
title('6.Clustering: VBEM (final)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');
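
As a side note: if the Statistics and Machine Learning Toolbox is available, the plain EM part can be cross-checked against Matlab's built-in fitter. A minimal sketch, not part of the original script (results can differ slightly depending on initialization):

gm = fitgmdist(x, 3, 'Replicates', 5);   % fit a 3-component GMM by built-in EM
gm_idx = cluster(gm, x);                 % hard cluster assignments
disp(gm.ComponentProportion);            % compare with Pw
disp(gm.mu);                             % compare with mu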


The code above uses a hand-written Gaussian pdf function, GaussPDF. The main reason is that Matlab's built-in mvnpdf sometimes reports that sigma is not positive definite even when sigma actually is positive definite.

The GaussPDF code is as follows:

function p = GaussPDF(x, mu, sigma)
% Multivariate Gaussian pdf evaluated at each of the N rows of x.

[N, D] = size(x);

x_u = x-repmat(mu, N, 1);           % center the data
sigma_inv = pinv(sigma);            % computed once, outside the loop
norm_const = sqrt(det(sigma)*(2*pi)^D);
p = zeros(N,1);
for ii=1:N,
    p(ii) = exp(-0.5*x_u(ii,:)*sigma_inv*x_u(ii,:)')/norm_const;
end;

end
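
If one would rather stay with the built-in mvnpdf, a common workaround, sketched here as an assumption rather than something the original post uses, is to symmetrize and slightly regularize sigma before the call, since the positive-definiteness check often trips on tiny numerical asymmetries:

% Hypothetical alternative to GaussPDF: repair sigma numerically,
% then call the built-in mvnpdf.
sigma_sym = (sigma + sigma')/2;                    % remove numerical asymmetry
sigma_reg = sigma_sym + 1e-10*eye(size(sigma,1));  % tiny jitter on the diagonal
p = mvnpdf(x, mu, sigma_reg);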


