注释：代码转载自 https://sites.google.com/site/zaiyang0248/publication
一、主程序
% Demo of off-grid DOA estimation with OGSBI.
% K sources impinge on an M-element uniform linear array (origin at the
% array centre). The snapshots are corrupted by complex white noise at the
% given SNR. OGSBI is run on a uniform angular grid, and the estimated
% power spectrum is plotted against the true DOAs.
clear;      % 'clear all' also unloads functions/MEX files; plain 'clear' is enough here
close all;
clc
resolution = 2;                     % grid step in degrees
doa_grid = (0:resolution:180)';     % candidate DOAs in degrees (not named 'grid' to avoid shadowing the built-in)
SNR = 0;                            % dB
M = 8;                              % number of sensors
K = 2;                              % number of sources
N = length(doa_grid);
T = 200;                            % number of snapshots
issvd = true;                       % project Y onto its dominant right-singular vectors before OGSBI
% issvd = false;

% sensor index relative to the array centre, m - (M+1)/2 for m = 1..M
m_pos = (1:M)' - (M+1)/2;

% true DOA (degrees)
theta = [60.3 88.6]';
% true Phi: Phi(m,k) = exp(1i*pi*(m-(M+1)/2)*cos(theta(k)))
Phi = exp(1i * pi * m_pos * cos(theta'/180*pi));

% true signal
% random signals
amp = [1 1]';
if T == 1
    % single snapshot: unit-modulus signals with uniformly random phase
    X = diag(amp) * exp(1i*2*pi*rand(K,1));
else
    X1 = (randn(K,T) + 1i * randn(K,T)) / sqrt(2);
    X = diag(amp) * X1;
end
% % correlated signals
% amp = [1 1]';
% x1 = (randn(1,T) + 1i *randn(1,T))/sqrt(2);
% y1 = (randn(1,T) + 1i *randn(1,T))/sqrt(2);
% x2 = .99*x1 + sqrt(1-.99^2)*y1;
% X = diag(amp) * [x1; x2];

% observed signal: noise variance chosen so that the per-entry SNR is SNR dB
Y0 = Phi * X;
sigma2 = 10^(-SNR/10) * norm(Y0,'fro')^2 / (M * T);
E = sqrt(sigma2)/sqrt(2)*randn(M,T) + 1i*sqrt(sigma2)/sqrt(2)*randn(M,T);
Y = Y0 + E;

% uniform linear array (ULA), with the origin at the middle:
% A holds the steering vectors on the grid, B their derivatives wrt. the angle (radians)
grid_rad = doa_grid'/180*pi;
A = exp(1i * pi * m_pos * cos(grid_rad));
B = -1i * pi * (m_pos * sin(grid_rad)) .* A;

% initialize parameters
if issvd
    [~, ~, V] = svd(Y, 'econ');
    % the economy SVD of an M x T matrix has only min(M,T) right-singular
    % vectors, so keep at most that many (T == 1 would otherwise fail)
    V = V(:, 1:min(K, size(V, 2)));
    Y = Y * V;
end
params.Y = Y;
params.A = A;
params.B = B;
params.resolution = resolution/180*pi;
params.rho = 1e-2;
params.alpha = mean(abs(A'*Y), 2);
params.beta = zeros(N,1);
params.K = K;
params.maxiter = 2000;
params.tolerance = 1e-3;
params.sigma2 = mean(var(Y))/100;
% To use the true noise variance instead of estimating it, uncomment:
% params.isKnownNoiseVar = true;
% params.knownsigma2 = sigma2;
tstart = tic;
res = OGSBI(params);
time = toc(tstart);

% line plot: refined DOAs = grid point + estimated off-grid offset
xp_rec = doa_grid + res.beta * 180 / pi;
if issvd
    % map the reduced-dimension estimate back to all T snapshots; the
    % posterior variance contributes once per retained singular vector
    x_rec = res.mu * V';
    xpower_rec = mean(abs(x_rec).^2,2) + real(diag(res.Sigma)) * size(V, 2) / T;
else
    xpower_rec = mean(abs(res.mu).^2,2) + real(diag(res.Sigma));
end
ptrue_db = 10*log10(amp.^2);
prec_db = 10*log10(xpower_rec);
figure(1000),plot(theta, ptrue_db, 'bo', xp_rec, prec_db, 'rx-');
axis([0,180,min([ptrue_db; prec_db]),max([ptrue_db; prec_db])+3]);
xlabel('DOA (degrees)', 'fontsize',12); ylabel('Power (dB)','fontsize',12);
legend('True DOAs','OGSBI spectral');
二、程序调用的函数
function res = OGSBI(paras)
% res = OGSBI(paras)
%
% OGSBI(paras) performs off-grid DOA estimation using Sparse Bayesian
% Inference. The dictionary is modelled as A + B*diag(beta), where beta
% holds per-grid-point offsets. In every iteration only the K entries with
% the largest alpha get a nonzero offset, each kept in [-r/2, r/2] with
% r = paras.resolution.
%
% Input:
%   paras.Y: M * T matrix, sensor measurements at all snapshots
%   paras.A: M * N matrix, columns are the steering vectors for different directions
%   paras.B: M * N matrix, columns are derivatives of the steering vectors wrt. different directions
%   paras.sigma2: initialization of noise variance (the noise precision starts at 1/sigma2)
%   paras.alpha: N * 1, initialization of alpha
%   paras.beta: N * 1, initialization of beta
%   paras.rho: rho (divided by T internally)
%   paras.resolution: grid resolution for the directions
%   paras.maxiter: maximum iteration
%   paras.tolerance: stopping criterion on the relative change of alpha
%                    (paras.tol is accepted as an alias if tolerance is absent)
%   paras.isKnownNoiseVar: (optional) true if known variance, false (default) if unknown
%   paras.knownsigma2: noise variance, read only when isKnownNoiseVar is true
%   paras.K: (optional) number of sources, default min(T, M-1)
%
% Output:
%   res.mu: mean estimation (N * T)
%   res.Sigma: variance estimation (N * N)
%   res.sigma2: estimated noise variance
%   res.sigma2seq: estimated noise variance at all iterations
%   res.alpha: reconstructed alpha
%   res.beta: reconstructed beta
%   res.iter: iteration used in the algorithm
%   res.ML: maximum likelihood function value at all iterations
%
% Note: after the stopping criterion fires, beta receives one final update
% (5 coordinate-descent sweeps when needed). res.mu and res.Sigma are those
% computed with the beta from before that final update.
%
% Written by Zai Yang, 19 Jul, 2011
% reference:
% Z. Yang, L. Xie, and C. Zhang, "Off-grid direction of arrival estimation ...
% using sparse Bayesian inference", IEEE Trans. Signal Processing, ...
% vol. 61, no. 1, pp. 38--43, 2013.

tiny = 1e-16;   % guards divisions against zero (not named eps to avoid shadowing the built-in)
Y = paras.Y;
A = paras.A;
B = paras.B;
[M, T] = size(Y);
N = size(A, 2);
alpha0 = 1 / paras.sigma2;      % noise precision
rho = paras.rho / T;
beta = paras.beta;
alpha = paras.alpha;
r = paras.resolution;
maxiter = paras.maxiter;
if isfield(paras, 'tolerance') && ~isempty(paras.tolerance)
    tol = paras.tolerance;
elseif isfield(paras, 'tol') && ~isempty(paras.tol)
    tol = paras.tol;
else
    error('OGSBI:missingTolerance', 'Set paras.tolerance (or paras.tol).');
end
if isfield(paras, 'isKnownNoiseVar') && ~isempty(paras.isKnownNoiseVar)
    isKnownNoiseVar = paras.isKnownNoiseVar;
else
    isKnownNoiseVar = false;
end
% (a, b) enter the alpha0 update and the ML value as a Gamma-type prior on alpha0
if isKnownNoiseVar
    a = 1;
    b = T * M * paras.knownsigma2;
else
    a = 1e-4;
    b = 1e-4;
end
if isfield(paras, 'K') && ~isempty(paras.K)
    K = paras.K;
else
    K = min(T, M-1);
end

idx = [];
BHB = B' * B;
converged = false;
iter_beta = 1;
iter = 0;
ML = zeros(maxiter,1);
alpha0seq = zeros(maxiter,1);
while ~converged
    iter = iter + 1;
    % perturb only the columns selected in the previous iteration
    Phi = A;
    Phi(:,idx) = A(:,idx) + B(:,idx) * diag(beta(idx));
    alpha_last = alpha;

    % posterior mean/covariance of the sparse signal
    C = 1 / alpha0 * eye(M) + Phi * diag(alpha) * Phi';
    Cinv = inv(C);
    Sigma = diag(alpha) - diag(alpha) * Phi' * Cinv * Phi * diag(alpha);
    mu = alpha0 * Sigma * Phi' * Y;
    gamma1 = 1 - real(diag(Sigma)) ./ (alpha + tiny);

    % update alpha; with rho > 0 it is the positive root of
    % rho*x^2 + x - (musq + diag(Sigma)) = 0
    musq = mean(abs(mu).^2, 2);
    alpha = musq + real(diag(Sigma));
    if rho ~= 0
        alpha = -.5 / rho + sqrt(.25 / rho^2 + alpha / rho);
    end

    % update alpha0 (uses the previous alpha0 in the gamma1 term)
    resid = Y - Phi * mu;
    alpha0 = (T * M + a - 1) / (norm(resid, 'fro')^2 + T / alpha0 * sum(gamma1) + b);
    alpha0seq(iter) = alpha0;

    % stopping criteria (denominator guarded so an all-zero alpha does not give NaN)
    if norm(alpha - alpha_last) / max(norm(alpha_last), tiny) < tol || iter >= maxiter
        converged = true;
        iter_beta = 5;   % more coordinate-descent sweeps for the final beta
    end

    % objective value; sum_t real(Y(:,t)'*Cinv*Y(:,t)) written in vectorized form
    quad = real(sum(sum(conj(Y) .* (Cinv * Y))));
    ML(iter) = -T * real(log(det(C))) - quad + (a-1) * log(alpha0) - b * alpha0 - rho * sum(alpha);

    % update beta on the K grid points with the largest alpha
    [~, order] = sort(alpha, 'descend');
    idx = order(1:min(K, N));
    beta_prev = beta;
    beta = zeros(N,1);
    beta(idx) = beta_prev(idx);
    P = real(conj(BHB(idx,idx)) .* (mu(idx,:) * mu(idx,:)' + T * Sigma(idx,idx)));
    v = sum(real(conj(mu(idx,:)) .* (B(:,idx)' * (Y - A * mu))), 2);
    v = v - T * real(diag(B(:,idx)' * A * Sigma(:,idx)));

    % unconstrained solution P\v when P has no zero diagonal and it lies in
    % [-r/2, r/2]; otherwise box-constrained coordinate descent
    use_cd = any(diag(P) == 0);
    if ~use_cd
        beta_ls = P \ v;
        use_cd = any(abs(beta_ls) > r/2);
    end
    if use_cd
        for i = 1:iter_beta
            for n = 1:numel(idx)
                if P(n,n) == 0
                    beta(idx(n)) = 0;
                    continue;
                end
                others = beta(idx);
                others(n) = 0;
                bn = (v(n) - P(n,:) * others) / P(n,n);
                if bn > r/2
                    bn = r/2;
                elseif bn < -r/2
                    bn = -r/2;
                end
                beta(idx(n)) = bn;
            end
        end
    else
        beta(idx) = beta_ls;
    end
end
res.mu = mu;
res.Sigma = Sigma;
res.beta = beta;
res.alpha = alpha;
res.iter = iter;
res.ML = ML(1:iter);
res.sigma2 = 1/alpha0;
res.sigma2seq = 1./alpha0seq(1:iter);
end