gusucode.com > vision工具箱matlab源码程序 > vision/+vision/+internal/approximateKMeans.m

    % approximateKMeans Performs approximate K-Means clustering.
%   [centers, assignments] = approximateKMeans(features, K) clusters
%   features into K groups and returns the cluster centers and the feature
%   vector assignments to each cluster. features must be an M-by-N matrix,
%   where M is the number of features to cluster, and N is the dimension of
%   each feature vector. The output centers is a K-by-N matrix of cluster
%   centers. assignments is a 1-by-M array of cluster assignments.
%
%   [...] = approximateKMeans(...,Name,Value) specifies additional
%   name-value pair arguments described below:
%
%   'MaxIterations'   Maximum number of iterations before the K-Means
%                     algorithm is terminated.
%
%                     Default: 100
%
%   'Threshold'       When the change in the total sum of intra-cluster
%                     distances is below the Threshold, the K-Means
%                     algorithm is terminated.
%
%                     Default: 0.0001
%
%   'NumTrials'       The number of times the K-Means algorithm is run with
%                     different initial cluster centers. The solution with
%                     the lowest total sum of intra-cluster distances is
%                     returned.
%
%                     Default: 1
%
%   'Initialization'  Specify the method used to initialize the cluster
%                     centers as 'Random' or 'KMeans++'.
%
%                     Default: 'KMeans++'
%
%

% References
% ----------
% J. Philbin, O. Chum, M. Isard, J. Sivic, and A. Zisserman. Object
% retrieval with large vocabularies and fast spatial matching. In Proc.
% Computer Vision and Pattern Recognition (CVPR), 2007
%
% Arthur, D. and Vassilvitskii, S. (2007). "k-means++: the advantages of
% careful seeding". Proceedings of the eighteenth annual ACM-SIAM symposium
% on Discrete algorithms. Society for Industrial and Applied Mathematics
% Philadelphia, PA, USA. pp. 1027-1035.

%   Copyright 2014 MathWorks, Inc.

function [bestCenters, bestAssignments] = approximateKMeans(features, K, varargin)
% Main driver. Runs params.NumTrials clustering trials and returns the
% centers and assignments of the trial with the smallest compactness (as
% computed by clusterCompactness). Each trial:
%   1) initializes the centers (initializeClusterCenters),
%   2) assigns every feature to a center (assignDataToClusters),
%   3) alternates assignment and center update for at most
%      params.MaxIterations iterations, stopping early once the relative
%      change in compactness is <= params.Threshold.
%
% NOTE(review): bestCenters/bestAssignments are only assigned inside the
% trial loop, so they stay unassigned if the loop body never completes a
% trial (e.g. params.NumTrials < 1) or if compactness is never smaller than
% the initial Inf (e.g. it is NaN).

params = parseInputs(features, K, varargin{:});

printer = vision.internal.MessagePrinter.configure(params.Verbose);

% N tracks the current number of feature rows; it is refreshed below if
% invalid rows are dropped.
N = size(features,1);

% Start at +Inf (same class as features) so the first completed trial wins.
bestCompactness = inf('like', features);

printer.printMessage('vision:kmeans:numFeatures', N);
printer.printMessage('vision:kmeans:numClusters', K).linebreak;

trial = 1;

% A while-loop (rather than for) is used because a pass that drops invalid
% features is restarted via 'continue' without incrementing the trial count.
while trial <= params.NumTrials
    
    centers = initializeClusterCenters(features, K, params, printer);
    
    printTrialStartMessage(printer, trial, params);
    
    % Print the "iteration 0" progress line; msg holds the text last
    % printed so the next update can overwrite it.
    msg = printProgress(printer, '', 0, params.MaxIterations,'');
    
    % The third output is requested only here; it flags which feature rows
    % the neighbor search considered valid.
    [assignments, dists, isValid] = assignDataToClusters(features, centers, params);
    
    % Remove invalid features that contain Inf or NaN
    if any(~isValid)
        features = features(logical(isValid),:);
    
        if isempty(features)
            error(message('vision:kmeans:allHadInfNaN'));
        elseif size(features, 1) < K
            error(message('vision:kmeans:tooManyInfNaN', K));
        else
            % N still holds the row count before the drop at this point.
            warning(message('vision:kmeans:droppingInfNaN', ...
                N-size(features,1), N));
        end
        
        N = size(features,1);
        
        % Do not count this as a valid trial.
        continue
    end
    
    [centers, assignments] = updateClusterCenters(features, assignments, dists, K, params);
    
    % Baseline values that the first iteration is compared against.
    prevCompactness = clusterCompactness(features, centers, assignments);
    prevDist        = dists;
    prevAssignments = assignments;
          
    for i = 1:params.MaxIterations
        start = tic;
        [assignments, dists] = assignDataToClusters(features, centers, params);
        
        % The approximate neighbor search can produce a worse cluster
        % assignment than the previous iteration. Keep the old assignment when
        % this happens.
        idx = (prevAssignments~=assignments) & (prevDist < dists);
        
        % NOTE(review): only the assignment is reverted here; dists(idx)
        % keeps the distance of the rejected assignment and is what gets
        % stored in prevDist below.
        assignments(idx) = prevAssignments(idx);
        
        [centers, assignments] = updateClusterCenters(features, assignments, dists, K, params);
        
        % Evaluate termination criteria
        compactness = clusterCompactness(features, centers, assignments);
                
        % Relative change in compactness; eps(single(1)) keeps the
        % denominator nonzero when prevCompactness is 0.
        delta = abs(prevCompactness - compactness)/(prevCompactness + eps(single(1)));
        
        elapsedTimeMessage = sprintf('(~%.2f seconds/iteration)',toc(start));
        msg = printProgress(printer, msg, i, params.MaxIterations, elapsedTimeMessage);
        
        if delta <= params.Threshold;
            printer.printMessage('vision:kmeans:trialEnd', i);
            break;
        end
        
        prevCompactness = compactness;
        prevDist        = dists;
        prevAssignments = assignments;                
    end
    
    % Keep the result of this trial only if it beats the best one so far.
    if compactness < bestCompactness
        bestCompactness = compactness;
        bestCenters     = centers;
        bestAssignments = assignments;
    end
    
    % completed trial
    trial = trial + 1;
end

printer.linebreak;

%--------------------------------------------------------------------------
function [assignments, dists, varargout] = assignDataToClusters(features, centers, params)
% Dispatches the nearest-center search to the serial or parallel
% implementation depending on params.UseParallel. Any output requested
% beyond the first two is forwarded from the selected implementation.

numExtraOutputs = nargout - 2;

% Snapshot the random number generator state once per assignment pass. It
% is handed to the worker routines so that KD-Tree indexing is
% deterministic on all workers in the parallel code path.
rngSnapshot = rng;

if ~params.UseParallel
    [assignments, dists, varargout{1:numExtraOutputs}] = ...
        assignDataToClustersSerial(features, centers, rngSnapshot);
    return
end

[assignments, dists, varargout{1:numExtraOutputs}] = ...
    assignDataToClustersParallel(features, centers, rngSnapshot);

%--------------------------------------------------------------------------
function [assignments, dists, varargout] = assignDataToClustersSerial(features, centers, randState)
% Finds, for every row of features, the closest row of centers using a
% KD-Tree built over the centers. Returns the index of that center, the
% distance reported by the search and, when a third output is requested,
% whatever the search returns as its third output (the caller treats it as
% a per-feature validity flag).

% Nothing to search for: hand back empty results of every requested output.
if isempty(features)
    assignments = [];
    dists       = [];
    if nargout == 3
        varargout{1} = [];
    end
    return
end

kdtree = vision.internal.Kdtree();

% Install the supplied rand state explicitly before indexing, then restore
% whatever state was active. This lets the parallel code paths pass the
% same state to every worker and obtain deterministic results.
savedState = rng(randState);
kdtree.index(centers);
rng(savedState);

% Search options, forwarded unchanged to knnSearch.
searchOpts = struct( ...
    'checks',            int32(32), ...
    'eps',               single(0), ...
    'grainSize',         int32(10000), ...
    'tbbQueryThreshold', uint32(10000));

% Only the single closest neighbor is requested.
[assignments, dists, varargout{1:nargout-2}] = kdtree.knnSearch(features, 1, searchOpts);
%--------------------------------------------------------------------------
function [assignments, dists, varargout] = assignDataToClustersParallel(features, centers, randState)
% Parallel version of the nearest-center search. The features are split
% into equally sized chunks that are searched inside a parfor loop; rows
% left over after chunking are searched serially afterwards. Results are
% concatenated in feature order.
%
% Outputs:
%   assignments  - index of the closest center for every feature
%   dists        - distance reported by the search for every feature
%   varargout{1} - (only when requested) per-feature validity flags
%
% When no parallel pool is available, or when there are fewer features
% than workers, all work is delegated to assignDataToClustersSerial.

[numFeatures, featureDim] = size(features);

% get the current parallel pool
pool = gcp();

% Divide the work evenly amongst the workers. This helps minimize the
% number of indexing operations.
if isempty(pool)
    chunkSize = 0;
else
    chunkSize = floor(numFeatures/pool.NumWorkers);
end

if chunkSize == 0
    % Serial fallback. Forward every requested output: the previous code
    % dropped the optional third output in the no-pool case, which broke
    % callers asking for it. The fallback also avoids reshaping into
    % zero-width chunks when there are fewer features than workers.
    [assignments, dists, varargout{1:nargout-2}] = ...
        assignDataToClustersSerial(features, centers, randState);
    return
end

% The remainder is processed in serial
remainder = rem(numFeatures, chunkSize);

% features are reshaped into 3-D array to avoid data copies between
% workers. Each page of the cube holds one chunk, stored transposed.
featuresCube = reshape(features(1:end-remainder,:)', featureDim, chunkSize, []);

% Reduction variables, grown by concatenation inside the parfor loop.
assignments = [];
dists       = [];
isValid     = [];

parfor n = 1:size(featuresCube,3)
    
    % Restore the chunk to one-feature-per-row layout.
    f = reshape(featuresCube(:,:,n),featureDim,[])';
    
    [tassignments, tdists, tisValid] = assignDataToClustersSerial(f, centers, randState);
    
    assignments = [assignments tassignments];
    dists       = [dists tdists];
    isValid     = [isValid; tisValid];
end

% finish the remainder
[tassignments, tdists, tisValid] = assignDataToClustersSerial(...
    features(end-remainder+1:end,:), centers, randState);

assignments = [assignments tassignments];
dists       = [dists tdists];
isValid     = [isValid; tisValid];

if nargout == 3
    varargout{1} = isValid;
end

%--------------------------------------------------------------------------
function [centers, assignments] = updateClusterCenters(features, assignments, dists, K, params)
% Recomputes the cluster centers from the current assignments, using the
% serial or the parallel implementation according to params.UseParallel.
% The assignments are returned as well because the update may move
% features into clusters that would otherwise be empty.

if ~params.UseParallel
    [centers, assignments] = updateClusterCentersSerial(features, assignments, dists, K);
    return
end

[centers, assignments] = updateClusterCentersParallel(features, assignments, dists, K);

%--------------------------------------------------------------------------
function [centers, assignments] = updateClusterCentersSerial(features, assignments, dists, K)
% Serial center update in three steps: sum the features of every cluster,
% repopulate clusters that ended up empty, then divide the sums by the
% cluster sizes.

% Step 1: per-cluster feature sums and member counts.
[clusterSums, clusterSizes] = sumClusterFeatures(features, assignments, K);

% Step 2: give each empty cluster a feature taken from another cluster.
[clusterSums, assignments, clusterSizes] = reinitializeEmptyClusters( ...
    features, assignments, clusterSums, clusterSizes, dists);

% Step 3: turn sums into means.
centers = computeClusterCenters(clusterSums, clusterSizes);

%--------------------------------------------------------------------------
% Returns updated cluster centers and cluster assignments. The cluster
% update is done in parallel by computing partial cluster summations in
% parallel and then averaging at the end.
%
%    1) Split up all the features into distinct sub-sets.
%    2) For each feature sub-set, sum the contribution of each feature to
%       its assigned cluster. And keep track of the number of features
%       belonging to each cluster. The overall cluster sum is tabulated as
%       a parallel reduction within the parfor loop.
%    3) After all sub-sets are processed in parallel, serially compute the
%       cluster centers using the cluster sums and counts.
%--------------------------------------------------------------------------
function [centers, assignments] = updateClusterCentersParallel(features, assignments, dists, K)
% Parallel center update. Partial per-cluster sums and counts are computed
% for equally sized chunks of the features inside a parfor loop (as a
% reduction), the leftover rows are added serially, and the centers are
% then derived from the combined sums and counts.
%
% When no parallel pool is available, or when there are fewer features
% than workers, the work is delegated to updateClusterCentersSerial, which
% produces the same result.

[numFeatures, featureDim] = size(features);

% get the current parallel pool
pool = gcp();

% Divide the work evenly amongst the workers. This helps minimize the
% number of indexing operations.
if isempty(pool)
    chunkSize = 0;
else
    chunkSize = floor(numFeatures/pool.NumWorkers);
end

if chunkSize == 0
    % With a zero chunk size there is nothing to process in parallel, and
    % reshaping into zero-width chunks with an inferred page count is not
    % well defined. Every feature would end up in the serial remainder
    % anyway, so use the serial routine directly.
    [centers, assignments] = updateClusterCentersSerial(features, assignments, dists, K);
    return
end

% The remainder is processed in serial
remainder = rem(numFeatures, chunkSize);

% Data is reshaped into 3-D array to avoid data copies between workers.
assignmentsCube = reshape(assignments(1:end-remainder), 1, chunkSize, []);
featuresCube    = reshape(features(1:end-remainder,:)', featureDim, chunkSize, []);

% Process chunks of the data in parallel and compute partial cluster
% sums. These partial sums are then averaged serially for the final
% cluster center.
centerSums = zeros(K, featureDim);
counts     = zeros(K,1);
parfor n = 1:size(featuresCube,3)
    % Restore the chunk to one-feature-per-row layout.
    f = reshape(featuresCube(:,:,n),featureDim,[])';
    
    a = assignmentsCube(:,:,n);
    
    [partialSums, partialCounts] = sumClusterFeatures(f, a, K);
    
    centerSums = centerSums + partialSums;
    counts     = counts     + partialCounts;
end

% finish the remainder
f = features(end-remainder+1:end,:);
a = assignments(end-remainder+1:end);

[partialSums, partialCounts] = sumClusterFeatures(f, a, K);

centerSums = centerSums + partialSums;
counts     = counts     + partialCounts;

[centerSums, assignments, counts] = reinitializeEmptyClusters(features, assignments, centerSums, counts, dists);

centers = computeClusterCenters(centerSums, counts);

%--------------------------------------------------------------------------
function centers = computeClusterCenters(centerSums, counts)
% Converts per-cluster feature sums into cluster means, returned as single.
% Row k of centerSums is multiplied by 1/(counts(k)+eps); the added eps
% avoids a division by zero for clusters without members.

numClusters = numel(counts);
reciprocal  = 1 ./ (counts + eps);

% A sparse diagonal matrix performs the row scaling without the storage
% cost of a full K-by-K matrix.
rowScaling = spdiags(reciprocal, 0, numClusters, numClusters);

scaledSums = rowScaling * centerSums;
centers    = single(full(scaledSums));

%--------------------------------------------------------------------------
function [accum, counts] = sumClusterFeatures(features, assignments, K)
% Sums the features assigned to each cluster and counts the members of
% each cluster. Used by the cluster center update.
%
%   accum  - K-by-N double matrix, row k is the sum of the features
%            assigned to cluster k (zeros if there are none)
%   counts - K-by-1 vector, number of features assigned to each cluster

[numFeatures, featureDim] = size(features);

% Sparse membership matrix: entry (m,k) is set when feature m belongs to
% cluster k. Building it once avoids per-feature checks inside the loop.
membership = sparse(1:numFeatures, double(assignments), logical(assignments), ...
    numFeatures, K);

counts = zeros(K, 1);
accum  = zeros(K, featureDim);

for clusterIdx = 1:K
    inCluster = membership(:, clusterIdx);
    
    counts(clusterIdx)   = nnz(inCluster);
    
    % accumulate in double for precision.
    accum(clusterIdx, :) = sum(features(inCluster, :), 1, 'double');
end

%--------------------------------------------------------------------------
% Returns updated cluster sums, assignments and counts. For each empty
% cluster, reinitialize it using a feature that is the furthest from any
% other cluster center, taking care not to create more empty clusters in
% the process.
%--------------------------------------------------------------------------
function [centerSums, assignments, counts] = reinitializeEmptyClusters(...
    features, assignments, centerSums, counts, dists)
% Repopulates every cluster whose count is zero. For each empty cluster
% the feature with the largest entry in dists is considered first; it is
% moved into the empty cluster only if its current cluster keeps at least
% one other member, so that no new empty cluster is created. A feature is
% considered at most once across all empty clusters. If no suitable
% feature remains, the cluster is left empty.
%
% Inputs/outputs:
%   centerSums  - K-by-N per-cluster feature sums, updated in place
%   assignments - cluster index of every feature, updated in place
%   counts      - K-by-1 cluster sizes, updated in place
%   dists       - per-feature distance used to rank the candidates

emptyClusterIdx = find(counts == 0);

for i = 1:numel(emptyClusterIdx)
    
    empty = emptyClusterIdx(i);
    
    clusterIsEmpty = true;
    while clusterIsEmpty
        
        [maxValue, idx] = max(dists);
        
        if isempty(maxValue) || maxValue == -inf
            % No alternate choices left. The isempty test keeps the loop
            % from spinning forever when dists has no elements.
            break;
        end
        
        % Prevent feature from being selected again
        dists(idx) = -inf(1,'like',dists);
        
        % Remove feature from assigned cluster only if another empty
        % cluster is not created in the process.
        previous = assignments(idx);
        if counts(previous) > 1
            
            assignments(idx) = empty;
            
            % Promote the feature to double before touching the sums:
            % double minus single is evaluated in single precision, which
            % would discard the precision of the double accumulator.
            movedFeature = double(features(idx, :));
            
            % remove feature from its previous cluster
            centerSums(previous, :) = centerSums(previous, :) - movedFeature;
            counts(previous)        = counts(previous) - 1;
            
            % and move feature to empty cluster
            centerSums(empty, :) = movedFeature;
            counts(empty)        = 1;
            
            clusterIsEmpty = false;
        end
    end
    
end

%--------------------------------------------------------------------------
function compactness = clusterCompactness(features, centers, assignments)
% Total squared distance between every feature and the center of the
% cluster it is assigned to.

assignedCenters = centers(assignments, :);
residuals       = assignedCenters - features;

% Squared distance per feature first, then the total over all features.
squaredDistPerFeature = sum(residuals.^2, 2);
compactness           = sum(squaredDistPerFeature);

%--------------------------------------------------------------------------
function centers = initializeClusterCenters(features, K, params, printer)
% Picks the K initial cluster centers. params.Initialization selects the
% method: 'random' (compared case-insensitively) draws K features at
% random; any other value uses KMeans++ seeding.

wantsRandom = strcmpi(params.Initialization, 'random');

if ~wantsRandom
    centers = kmeansPlusPlusInit(features, K, printer);
    return
end

centers = randomClusterInit(features, K);

%--------------------------------------------------------------------------
% Select cluster centers randomly.
function centers = randomClusterInit(features, K)
% Uses K distinct, randomly chosen rows of features as the initial centers.

numFeatures = size(features, 1);

% randperm(n,k) returns k unique indices, so no feature is picked twice.
pickedRows = randperm(numFeatures, K);
centers    = features(pickedRows, :);

%--------------------------------------------------------------------------
% Initialize cluster centers using KMeans++.
function centers = kmeansPlusPlusInit(features, K, printer)
% KMeans++ seeding. The first center is a uniformly random feature. Every
% further center is drawn with probability proportional to the feature's
% distance (as returned by visionSSDMetric) to the nearest center chosen
% so far, which favors features far away from the existing centers.
% Progress is reported through printer.

[numFeatures, featureDim] = size(features);

centers      = zeros(K, featureDim, 'like', features);
chosenRows   = zeros(1, K);
nearestDist  = inf(numFeatures, 1, 'like', features);
unity        = ones(1, 'like', features);

% Select first center randomly
chosenRows(1) = randi(numFeatures, 1);
centers(1, :) = features(chosenRows(1), :);

printer.printMessageNoReturn('vision:kmeans:initialization');

% visionSSDMetric is called with one feature per column.
featuresByColumn = features';

progressText = '';
for k = 2:K
    
    progressText = printInitProgress(printer, progressText, k, K);
    
    % Distances from all features to the most recently added center
    dists = visionSSDMetric(featuresByColumn, centers(k-1,:)');
    
    % Keep, per feature, the smallest distance to any center so far
    nearestDist = min(dists, nearestDist);
    
    % Normalize into sampling weights; eps guards a zero denominator.
    normalizer = sum(nearestDist) + eps(class(nearestDist));
    weights    = nearestDist ./ normalizer;
    
    % Cumulative distribution used for the weighted draw.
    cdfEdges = [0; cumsum(weights)];
    
    cdfEdges(end)              = unity; % CDF must end at 1
    cdfEdges(cdfEdges > unity) = unity; % and must have all values <= 1
    
    if ~all(isfinite(cdfEdges))
        % pick a random center when edges have Infs or NaNs
        chosenRows(k) = randi(numFeatures, 1);
    else
        % Bin of a uniform sample within the CDF = index of the feature
        chosenRows(k) = discretize(rand(1), cdfEdges);
    end
    
    centers(k, :) = features(chosenRows(k), :);
    
end
printer.print('.\n');

%--------------------------------------------------------------------------
function params = parseInputs(features, K, varargin)
% Validates the inputs of approximateKMeans and returns the options as a
% struct with fields MaxIterations, Threshold, Initialization, Verbose,
% NumTrials, UseParallel and K.
%
% Errors with 'vision:kmeans:numDataGTEqK' when features has fewer rows
% than K. Option values that fail validation raise the error produced by
% the respective validator.

if size(features, 1) < K
    error(message('vision:kmeans:numDataGTEqK'))
end

% The options are documented as Name,Value pairs, so they are registered
% with addParameter. addOptional would declare them as positional
% arguments, and a name such as 'Verbose' would then be taken as the value
% of the first positional argument instead of as an option name.
parser = inputParser();
parser.addParameter('MaxIterations', 100, @checkMaxIterations);
parser.addParameter('Threshold', single(.0001), @checkThreshold);
parser.addParameter('Initialization', 'KMeans++');
parser.addParameter('Verbose', false);
% NumTrials must allow at least one trial to run; otherwise the caller
% never produces a result. Non-finite values are rejected as well.
parser.addParameter('NumTrials', 1, ...
    @(x)isscalar(x) && isnumeric(x) && isfinite(x) && x >= 1);
parser.addParameter('UseParallel', vision.internal.useParallelPreference());

parser.parse(varargin{:});

% Resolve the initialization method to its canonical spelling.
initMethod = validatestring(parser.Results.Initialization, ...
    {'Random', 'KMeans++'}, mfilename);

vision.internal.inputValidation.validateLogical(parser.Results.Verbose,'Verbose');
useParallel = vision.internal.inputValidation.validateUseParallel(parser.Results.UseParallel);

% Store the options in the types used by the rest of the file.
params.MaxIterations  = double (parser.Results.MaxIterations);
params.Threshold      = single (parser.Results.Threshold);
params.Initialization = initMethod;
params.Verbose        = logical(parser.Results.Verbose);
params.NumTrials      = double (parser.Results.NumTrials);
params.UseParallel    = useParallel;
params.K              = K;

%--------------------------------------------------------------------------
function checkMaxIterations(val)
% Validator for 'MaxIterations': errors unless val is a real, finite,
% positive integer-valued numeric scalar.

allowedClasses     = {'numeric'};
requiredAttributes = {'scalar','integer','positive','real','finite'};

validateattributes(val, allowedClasses, requiredAttributes, mfilename);

%--------------------------------------------------------------------------
function checkThreshold(val)
% Validator for 'Threshold': errors unless val is a real, finite, positive
% numeric scalar.

allowedClasses     = {'numeric'};
requiredAttributes = {'scalar','positive','real','finite'};

validateattributes(val, allowedClasses, requiredAttributes, mfilename);

%--------------------------------------------------------------------------
function printTrialStartMessage(printer, trial, params)
% Announces the start of a clustering trial. When more than one trial is
% configured the message includes the trial number and the total number of
% trials; otherwise a plain "clustering" message is printed.

hasMultipleTrials = params.NumTrials > 1;

if ~hasMultipleTrials
    printer.printMessageNoReturn('vision:kmeans:clustering');
    return
end

printer.printMessageNoReturn('vision:kmeans:trialStart', trial, params.NumTrials);

%--------------------------------------------------------------------------
function updateMessage(printer, prevMessage, nextMessage)
% Overwrites a previously printed message: emits one backspace character
% per character of prevMessage, followed by nextMessage.

numToErase = numel(prevMessage); % figure how much to delete
eraser     = sprintf(repmat('\b', 1, numToErase));

printer.print([eraser nextMessage]);

%--------------------------------------------------------------------------
function nextMessage = printInitProgress(printer, prevMessage, k, K)
% Replaces the initialization progress shown so far with the percentage
% 100*k/K and returns the text that was handed to the printer.

percentDone = 100*k/K;

% The returned text ends in two literal '%' characters.
nextMessage = sprintf('%.2f%%%%', percentDone);

% One character fewer than the stored text is erased; presumably the
% printer renders the trailing '%%' as a single '%' (confirm against
% MessagePrinter.print).
shownBefore = prevMessage(1:end-1);

updateMessage(printer, shownBefore, nextMessage);

%--------------------------------------------------------------------------
function nextMessage = printProgress(printer, prevMessage, i, N, elapsed)
% Replaces the clustering progress line: builds the
% 'vision:kmeans:clusteringProgress' message from the iteration i, the
% total N and the elapsed-time text, prints it over prevMessage and
% returns the new text.

progressMessage = message('vision:kmeans:clusteringProgress', i, N, elapsed);
nextMessage     = getString(progressMessage);

updateMessage(printer, prevMessage, nextMessage);