gusucode.com > 用粒子滤波算法进行跟踪的matlab代码 > gmm_utilities/gmm_em.m

    function [g, fit] = gmm_em(s, g, N)
%function [g, fit] = gmm_em(s, g, N)
%
% Refine a Gaussian mixture model with N iterations of Expectation-Maximisation.
%
% INPUTS:
%   s - samples, one sample per column
%   g - initial gmm: g.x holds the component means (one per column), g.P the
%       covariances (g.P(:,:,i) for component i) and g.w the component weights
%   N - number of iterations of EM
%
% OUTPUT:
%   g - resultant gmm; g.w is normalised to sum to one
%   fit - negative log-likelihood of the samples, evaluated in the E-step of
%         the last iteration (i.e. for the mixture BEFORE the final M-step).
%         If N <= 0, no iteration runs and fit is evaluated for the input gmm.
%         fit is Inf when some sample has zero likelihood under every component.
%
% REFERENCES:
%   Figueiredo et al, On Fitting Mixture Models, 1999. Section 2.2, Equations 6 to 9.
%   Ian Nabney, NetLab, http://www.ncrg.aston.ac.uk/netlab/index.php
%
% Tim Bailey 2005.

NM = size(g.x, 2); % number of mixture components
NS = size(s, 2);   % number of data samples
g.w = g.w / sum(g.w);

if N <= 0
    % No EM iterations requested: still assign the 'fit' output, otherwise
    % requesting it would be a run-time error.
    fit = gmm_estep(s, g, NM, NS);
end

while N > 0
    N = N - 1;

    % E-step: fit, and normalised assignment weights w(i,j) of component i for sample j
    [fit, w] = gmm_estep(s, g, NM, NS);

    % M-step: compute new (x,P,w) values for gmm
    wsc = sum(w, 2);      % sum across columns: wsc(i) = sum(w(i,1:NS)), where i in NM
    g.w = wsc / sum(wsc); % every column of w sums to one, so sum(wsc) equals NS (up to rounding)

    for i=1:NM
        if wsc(i) > 0
            w_norm = w(i,:) ./ wsc(i); % sample weights for component i, sum(w_norm) == 1
            [g.x(:,i), g.P(:,:,i)] = sample_mean_weighted(s, w_norm);
            g.P(:,:,i) = checkP(g.P(:,:,i)); % check P has not collapsed
        end
        % Otherwise component i is assigned none of the samples (its new weight
        % is zero); keep its previous mean and covariance instead of dividing by zero.
    end
end

% E-step helper for gmm_em.
%   fit - negative log-likelihood of the samples s under gmm g
%   w   - NM x NS assignment weights: w(i,j) is proportional to
%         g.w(i) * gauss_likelihood(...) for component i and sample j, and each
%         column is normalised to sum to one
function [fit, w] = gmm_estep(s, g, NM, NS)
w = zeros(NM, NS); % preallocate rather than growing w inside the loop
for i=1:NM
    v = s - repvec(g.x(:,i), NS);
    w(i,:) = g.w(i) * gauss_likelihood(v, g.P(:,:,i));
end

wsr = sum(w, 1);       % sum across rows: wsr(j) = sum(w(1:NM,j)), where j in NS
fit = -sum(log(wsr));  % negative log-likelihood of fit (from NetLab gmmem.m and gmmpost.m)

% Samples with zero likelihood under every component have no defined
% assignment. As in NetLab 3.3 gmmpost.m, give them uniform weights 1/NM, and
% set their column sum to one so the normalisation below does not divide by zero.
zs = (wsr == 0);
w(:, zs) = 1/NM;
wsr = checkzeros(wsr);

w = w ./ reprow(wsr, NM); % normalise columns: sum(w(1:NM,j)) == 1, where j in NS

%
%

% repvec: tile the column vector x side by side N times,
% giving a size(x,1)-by-N matrix whose every column equals x.
function x = repvec(x,N)
colIdx = ones(1, N); % N copies of column index 1
x = x(:, colIdx);

% reprow: stack the row vector x on top of itself N times,
% giving an N-by-size(x,2) matrix whose every row equals x.
function x = reprow(x,N)
rowIdx = ones(1, N); % N copies of row index 1
x = x(rowIdx, :);

% checkzeros: return x with every element equal to zero replaced by one
% (e.g. so x can be used as a divisor); all other elements are unchanged.
function x = checkzeros(x)
isZero = (x == 0);
x(isZero) = 1;
% Equivalent arithmetic form: x = x + (x==0);

% checkP: replace a collapsed or invalid covariance matrix with the identity.
%
%   P = checkP(P) returns eye(size(P)) when P has a non-finite entry or the
%   smallest eigenvalue of its symmetric part (P+P')/2 is below eps; otherwise
%   P is returned unchanged.
%   P = checkP(P, minvar) uses minvar as the minimum allowed eigenvalue.
%
% The smallest eigenvalue is used instead of det(P): det multiplies all the
% eigenvalues, so it is tiny for a well-conditioned but small covariance
% (0.01*eye(10) has det 1e-20) yet can stay large for a near-singular one.
% Unlike min(svd(P)), eig also flags an indefinite P (negative eigenvalue).
% TODO: NetLab resets a collapsed component to its initial covariance (Pinit)
% rather than the identity; that would need Pinit passed in.
function P = checkP(P, minvar)
if nargin < 2
    minvar = eps;
end
if isempty(P)
    return
end
% Test finiteness first: eig does not accept NaN/Inf, and || short-circuits.
if ~all(isfinite(P(:))) || min(eig((P + P') / 2)) < minvar
    P = eye(size(P));
end