% Source: gusucode.com > 用粒子滤波算法进行跟踪的matlab代码 (MATLAB code for tracking with a particle filter) > gmm_utilities/gmm_em.m
function [g, fit] = gmm_em(s, g, N)
%function [g, fit] = gmm_em(s, g, N)
%
% Refine a Gaussian mixture model with N iterations of Expectation-Maximisation.
%
% INPUTS:
%   s - samples, one sample per column
%   g - initial gmm, a struct with fields
%         g.x - component means, one per column
%         g.P - component covariances, g.P(:,:,i) for component i
%         g.w - component weights (normalised here, so need not sum to 1)
%   N - number of iterations of EM
%
% OUTPUTS:
%   g   - resultant gmm. g.w is returned normalised; after at least one
%         iteration it is a column vector.
%   fit - negative log-likelihood of the samples. It is computed in the
%         E-step of the last iteration, i.e. for the gmm as it was BEFORE
%         the final M-step. If N <= 0 it is the fit of the initial gmm.
%         It is Inf if some sample has zero likelihood under every component.
%
% NOTES:
%   - Samples with zero likelihood under every component get equal
%     responsibility 1/NM for each component (as in NetLab 3.3 gmmpost.m),
%     instead of an all-zero responsibility column.
%   - A component that receives zero total responsibility keeps its previous
%     mean and covariance (its weight becomes 0). This avoids a division by
%     zero that would otherwise fill the component with NaNs.
%
% REFERENCES:
%   Figueiredo et al, On Fitting Mixture Models, 1999. Section 2.2, Equations 6 to 9.
%   Ian Nabney, NetLab, http://www.ncrg.aston.ac.uk/netlab/index.php
%
% Tim Bailey 2005.

NM = size(g.x, 2); % number of mixtures
NS = size(s, 2);   % number of data samples
g.w = g.w / sum(g.w);

if N <= 0 && nargout > 1
    % No EM iterations requested: report the fit of the initial gmm rather
    % than leaving the output argument unassigned.
    fit = -sum(log(sum(component_likelihoods(s, g), 1)));
end

while N > 0
    N = N - 1;

    % E-step: compute assignment likelihood.
    % w(i,j) is the weighted likelihood of sample j under component i.
    w = component_likelihoods(s, g);
    wsr = sum(w, 1); % sum across rows: wsr(j) = sum(w(1:NM,j)), where j in NS
    fit = -sum(log(wsr)); % negative log-likelihood of fit (from NetLab gmmem.m and gmmpost.m)

    % Samples that no component explains (wsr == 0) get equal responsibility
    % for every component, as in NetLab 3.3 gmmpost.m. Their divisor is then
    % set to one, so the normalisation below leaves the 1/NM values intact.
    izero = (wsr == 0);
    w(:, izero) = 1/NM;
    wsr = checkzeros(wsr); % avoid divide-by-zero
    w = w ./ reprow(wsr, NM); % normalise columns: sum(w(1:NM,j)) == 1, where j in NS

    % M-step: compute new (x,P,w) values for gmm
    wsc = sum(w, 2); % sum across columns: wsc(i) = sum(w(i,1:NS)), where i in NM
    g.w = wsc / sum(wsc); % note, sum(wsc) should equal NS (every column of w sums to 1)
    for i = 1:NM
        if wsc(i) > 0
            w_norm = w(i,:) / wsc(i); % sum(w_norm) == 1
            [g.x(:,i), g.P(:,:,i)] = sample_mean_weighted(s, w_norm);
            g.P(:,:,i) = checkP(g.P(:,:,i)); % check P has not collapsed
        end
        % Otherwise component i has no support (its new weight is 0): keep
        % its previous mean and covariance instead of dividing by zero.
    end
end

%
%

% Weighted likelihood of each sample under each mixture component.
% Row i holds g.w(i) * gauss_likelihood(v, g.P(:,:,i)), where v is s with the
% component mean g.x(:,i) subtracted from every column.
% NOTE(review): gauss_likelihood is defined elsewhere; presumably it returns the
% Gaussian density of each column of v under covariance P -- confirm.
function w = component_likelihoods(s, g)
NM = size(g.x, 2);
NS = size(s, 2);
w = zeros(NM, NS);
for i = 1:NM
    v = s - repvec(g.x(:,i), NS);
    w(i,:) = g.w(i) * gauss_likelihood(v, g.P(:,:,i));
end

% Replicate a column-vector N times
function x = repvec(x,N)
x = x(:, ones(1,N));

% Replicate a row-vector N times
function x = reprow(x,N)
x = x(ones(1,N), :);

% Check array for zero terms, change them to ones
function x = checkzeros(x)
x(x==0) = 1;

% Check covariance for collapse; if it has collapsed, reset it to the identity.
% The determinant test depends on the scale and dimension of the data.
function P = checkP(P)
%if any(abs(diag(P)) < 1e-9) % check trace
if det(P) < eps % check determinant
    P = eye(size(P));
end
% TODO: improve checkP. NetLab uses measure: if min(svd(P)) < MINCOV, P=Pinit; end
% Where Pinit is the original covariance for that component.