OGSBI MATLAB Code（转载）

注：代码转载自 https://sites.google.com/site/zaiyang0248/publication

 一、主程序 

% Reset the workspace, close all figures and clear the command window.
% Plain `clear` is used rather than `clear all`: `clear all` additionally
% flushes compiled functions and breakpoints, which slows later calls and
% is not needed to reset the variables this script uses.
clear;
close all;
clc

% Candidate DOA grid in degrees, from 0 to 180 with step `resolution`.
% NOTE: the variable `grid` shadows MATLAB's grid() plotting function
% inside this script.
resolution = 2;
grid = (0:resolution:180)';

% Signal-to-noise ratio in dB; used to scale the additive noise power.
SNR = 0;

M = 8;              % number of array sensors
K = 2;              % number of sources
N = length(grid);   % number of grid points
T = 200;            % number of snapshots

% When true, the M x T data are reduced to M x K with an SVD before OGSBI.
issvd = true;
% issvd = false;

% True DOAs in degrees; there must be exactly one per source (K entries).
theta = [60.3 88.6]';
assert(numel(theta) == K, 'theta must list exactly K true DOAs.');

% True steering matrix Phi (M x K) of a uniform linear array whose origin
% is at the array centre. The phase step between neighbouring sensors is
% pi*cos(theta), i.e. presumably half-wavelength spacing.
elemPos = (1:M)' - (M+1)/2;                       % M x 1 sensor positions
Phi = exp((1i*pi*elemPos) * cos(theta'/180*pi));

% True source signals X (K x T).
% Per-source amplitudes. ones(K,1) equals the former [1 1]' for K = 2 and
% stays consistent with K when the number of sources changes.
amp = ones(K,1);
if T == 1
    % Single snapshot: unit-modulus signals with uniformly random phase,
    % scaled by amp like the multi-snapshot branch. rand(K,1) draws the
    % same values as unifrnd(0,1,K,1) without the Statistics toolbox.
    X = diag(amp) * exp(1i*2*pi*rand(K,1));
else
    % Circular complex Gaussian signals with unit variance per entry.
    X1 = (randn(K,T) + 1i * randn(K,T)) / sqrt(2);
    X = diag(amp) * X1;
end

% % correlated signals (this variant is written for K = 2 sources)
% amp = [1 1]';
% x1 = (randn(1,T) + 1i *randn(1,T))/sqrt(2);
% y1 = (randn(1,T) + 1i *randn(1,T))/sqrt(2);
% x2 = .99*x1 + sqrt(1-.99^2)*y1;
% X = diag(amp) * [x1; x2];

% Noise-free observations; used to set the noise level from the SNR.
Y0 = Phi * X;
sigma2 = 10^(-SNR/10) * norm(Y0,'fro')^2 / (M * T);
% Additive circular complex Gaussian noise, variance sigma2 per entry.
E = sqrt(sigma2)/sqrt(2)*randn(M,T) + 1i*sqrt(sigma2)/sqrt(2)*randn(M,T);
% Noisy observations (M x T).
Y = Y0 + E;

% Dictionary of the uniform linear array (ULA), origin at the array centre:
%   A(:,n) - steering vector for grid direction grid(n)
%   B(:,n) - derivative of A(:,n) with respect to that direction (radians)
% Both are built as rank-1 outer products of sensor positions and grid
% angles instead of an explicit double loop.
arrayPos = (1:M)' - (M+1)/2;        % M x 1 sensor positions
gridRad = grid'/180*pi;             % 1 x N grid directions in radians
A = exp((1i*pi*arrayPos) * cos(gridRad));
B = ((-1i*pi*arrayPos) * sin(gridRad)) .* A;

% Optional preprocessing: project the M x T data onto its K dominant right
% singular vectors, reducing it to M x K. `d` is kept because the power
% reconstruction after OGSBI maps the estimates back with d(:,1:K)'.
if issvd
    [~, ~, d] = svd(Y, 'econ');
    Y = Y * d(:, 1:K);
end

% Inputs and initial values for OGSBI.
params.Y = Y;
params.A = A;
params.B = B;
params.resolution = resolution/180*pi;   % grid spacing in radians

params.rho = 1e-2;
% Initial alpha: mean magnitude of the correlation of the data with each
% steering vector.
params.alpha = mean(abs(A'*Y), 2);
params.beta = zeros(N,1);                % start with zero off-grid offsets
params.K = K;

params.maxiter = 2000;
params.tolerance = 1e-3;

% Initial noise variance: 1% of the mean per-column sample variance of Y.
params.sigma2 = mean(var(Y))/100;

% To run OGSBI with a known noise variance, uncomment both lines
% (these are the field names OGSBI reads):
% params.isKnownNoiseVar = true;
% params.knownsigma2 = sigma2;

% Run OGSBI and measure its wall-clock time.
tstart = tic;
res = OGSBI(params);
time = toc(tstart);

% Refined DOAs in degrees: grid point plus the estimated off-grid offset
% (res.beta is in radians).
xp_rec = grid + res.beta * 180 / pi;

% Estimated power per grid point: mean squared magnitude of the posterior
% mean plus the posterior-variance contribution.
if ~issvd
    xpower_rec = mean(abs(res.mu).^2,2) + real(diag(res.Sigma));
else
    % Map the reduced columns back to the T original snapshots.
    x_rec = res.mu * d(:,1:size(res.mu,2))';
    xpower_rec = mean(abs(x_rec).^2,2) + real(diag(res.Sigma)) * K / T;
end

% Powers in dB, computed once and reused for the plot and its limits.
truePowDb = 10*log10(amp.^2);
estPowDb = 10*log10(xpower_rec);
allPowDb = [truePowDb; estPowDb];

figure(1000), plot(theta, truePowDb, 'bo', xp_rec, estPowDb, 'rx-');
axis([0, 180, min(allPowDb), max(allPowDb) + 3]);
xlabel('DOA (degrees)', 'fontsize', 12);
ylabel('Power (dB)', 'fontsize', 12);
legend('True DOAs', 'OGSBI spectral');

二、程序调用的函数 

function res = OGSBI(paras)
% res = OGSBI(paras)
%
% OGSBI(paras) performs off-grid DOA estimation using sparse Bayesian
% inference. Each iteration:
%   1. computes the signal posterior (mu, Sigma) for the current
%      dictionary A + B*diag(beta);
%   2. updates the signal variances alpha and the noise precision alpha0;
%   3. updates the off-grid offsets beta for the K grid points with the
%      largest alpha.
%
% Input:
% paras.Y: M * T matrix, sensor measurements at all snapshots
% paras.A: M * N matrix, columns are the steering vectors for different directions
% paras.B: M * N matrix, columns are derivatives of the steering vectors wrt. different directions
% paras.sigma2: initialization of noise variance
% paras.alpha: N * 1 vector, initialization of alpha
% paras.beta: N * 1 vector, initialization of beta. Only the entries at the
%             K largest alpha of the first iteration are kept.
% paras.rho: rho (divided by T internally)
% paras.resolution: grid resolution for the directions, in the same unit
%                   as beta. beta is kept within [-resolution/2, resolution/2].
% paras.maxiter: maximum iteration
% paras.tolerance: stopping criterion on the relative change of alpha
%                  (paras.tol is accepted as an alias)
% paras.isKnownNoiseVar: optional, true if known variance, false (default) if unknown
% paras.knownsigma2: known noise variance, required when isKnownNoiseVar is true
% paras.K: optional, number of sources (default min(T, M-1)), capped at N
%
% Output:
% res.mu: N * T posterior mean estimation
% res.Sigma: N * N posterior covariance estimation
% res.sigma2: estimated noise variance (1/alpha0)
% res.sigma2seq: estimated noise variance at all iterations
% res.alpha: reconstructed alpha
% res.beta: reconstructed beta. It comes from the beta update that runs
%           after the last (mu, Sigma) computation.
% res.iter: iteration used in the algorithm
% res.ML: at each iteration, -T*log(det(C)) - sum_t y_t'*inv(C)*y_t
%         + (a-1)*log(alpha0) - b*alpha0 - rho*sum(alpha)
%
% Written by Zai Yang, 19 Jul, 2011
% reference:
% Z. Yang, L. Xie, and C. Zhang, "Off-grid direction of arrival estimation ...
% using sparse Bayesian inference", IEEE Trans. Signal Processing, ...
% vol. 61, no. 1, pp. 38--43, 2013.

% Guards the division by alpha below. It is named `tiny` so that it does
% not shadow the builtin eps.
tiny = 1e-16;

Y = paras.Y;
A = paras.A;
B = paras.B;

[M, T] = size(Y);
N = size(A, 2);

alpha0 = 1 / paras.sigma2;
rho = paras.rho / T;
beta = paras.beta;
alpha = paras.alpha;
r = paras.resolution;

maxiter = paras.maxiter;
% Accept both the field name used by callers (tolerance) and the name
% this help text used to document (tol).
if isfield(paras, 'tolerance') && ~isempty(paras.tolerance)
    tol = paras.tolerance;
elseif isfield(paras, 'tol') && ~isempty(paras.tol)
    tol = paras.tol;
else
    error('OGSBI:missingTolerance', 'paras.tolerance (or paras.tol) must be provided.');
end

if isfield(paras, 'isKnownNoiseVar') && ~isempty(paras.isKnownNoiseVar)
    isKnownNoiseVar = paras.isKnownNoiseVar;
else
    isKnownNoiseVar = false;
end

% a and b enter the alpha0 update and the objective as
% (a-1)*log(alpha0) - b*alpha0, i.e. a Gamma(a, b)-type prior on alpha0.
% With a known variance, b is set from knownsigma2; alpha0 is still updated.
if isKnownNoiseVar
    a = 1;
    b = T * M * paras.knownsigma2;
else
    a = 1e-4;
    b = 1e-4;
end

if isfield(paras, 'K') && ~isempty(paras.K)
    K = paras.K;
else
    K = min(T, M-1);
end
% At most N grid points can be selected for the beta update.
K = min(K, N);

idx = [];
BHB = B' * B;
converged = false;
iter_beta = 1;   % coordinate-descent sweeps for beta; raised to 5 for the final update
iter = 0;
ML = zeros(maxiter,1);
alpha0seq = zeros(maxiter,1);

while ~converged
    iter = iter + 1;

    % Current dictionary: first-order off-grid correction on the selected columns.
    Phi = A;
    Phi(:,idx) = A(:,idx) + B(:,idx) * diag(beta(idx));

    alpha_last = alpha;

    % Posterior covariance and mean of the signal given alpha, alpha0, Phi.
    C = 1 / alpha0 * eye(M) + Phi * diag(alpha) * Phi';
    Cinv = inv(C);
    Sigma = diag(alpha) - diag(alpha) * Phi' * Cinv * Phi * diag(alpha);
    mu = alpha0 * Sigma * Phi' * Y;

    gamma1 = 1 - real(diag(Sigma)) ./ (alpha + tiny);

    % update alpha
    musq = mean(abs(mu).^2, 2);
    alpha = musq + real(diag(Sigma));
    if rho ~= 0
        alpha = -.5 / rho + sqrt(.25 / rho^2 + alpha / rho);
    end

    % update alpha0 (uses the previous alpha0 in the gamma1 term)
    resid = Y - Phi * mu;
    alpha0 = (T * M + a - 1) / (norm(resid, 'fro')^2 + T / alpha0 * sum(gamma1) + b);
    alpha0seq(iter) = alpha0;

    % stopping criteria; the final beta update uses more sweeps
    if norm(alpha - alpha_last)/norm(alpha_last) < tol || iter >= maxiter
        converged = true;
        iter_beta = 5;
    end

    % Objective value, with C from this iteration and the updated alpha,
    % alpha0. log(det(C)) is taken from a Cholesky factor of the
    % Hermitian part of C to avoid under/overflow of det for larger M. It
    % falls back to det if the factorization fails.
    [R, p] = chol((C + C') / 2);
    if p == 0
        logdetC = 2 * sum(log(real(diag(R))));
    else
        logdetC = real(log(det(C)));
    end
    quadY = real(sum(sum(conj(Y) .* (Cinv * Y))));   % sum_t y_t' * Cinv * y_t
    ML(iter) = -T * logdetC - quadY + (a-1) * log(alpha0) - b * alpha0 - rho * sum(alpha);

    % update beta on the K grid points with the largest alpha
    [~, idx] = sort(alpha, 'descend');
    idx = idx(1:K);

    beta_prev = beta;
    beta = zeros(N,1);
    beta(idx) = beta_prev(idx);

    % Quadratic problem in beta(idx): minimize beta'*P*beta - 2*v'*beta.
    P = real(conj(BHB(idx,idx)) .* (mu(idx,:) * mu(idx,:)' + T * Sigma(idx,idx)));
    v = real(sum(conj(mu(idx,:)) .* (B(:,idx)' * (Y - A * mu)), 2));
    v = v - T * real(diag(B(:,idx)' * A * Sigma(:,idx)));

    beta_ls = P \ v;
    if any(abs(beta_ls) > r/2) || any(diag(P) == 0) || ~all(isfinite(beta_ls))
        % The unconstrained solution leaves [-r/2, r/2], or P is singular
        % so the solution is unusable. Run iter_beta sweeps of coordinate
        % descent, clipping each entry to [-r/2, r/2].
        for sweep = 1:iter_beta
            for n = 1:K
                if P(n,n) == 0
                    beta(idx(n)) = 0;
                    continue;
                end
                temp_beta = beta(idx);
                temp_beta(n) = 0;
                bn = (v(n) - P(n,:) * temp_beta) / P(n,n);
                if bn > r/2
                    bn = r/2;
                end
                if bn < -r/2
                    bn = -r/2;
                end
                beta(idx(n)) = bn;
            end
        end
    else
        beta = zeros(N,1);
        beta(idx) = beta_ls;
    end
end

res.mu = mu;
res.Sigma = Sigma;
res.beta = beta;
res.alpha = alpha;
res.iter = iter;
res.ML = ML(1:iter);
res.sigma2 = 1/alpha0;
res.sigma2seq = 1./alpha0seq(1:iter);

end

猜你喜欢

转载自 https://blog.csdn.net/qq_25634581/article/details/85682646